Skip to content

Commit

Permalink
feat: always read the authorized keys file (#88)
Browse files Browse the repository at this point in the history
Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>
  • Loading branch information
caarlos0 committed Nov 15, 2022
1 parent 1da25a2 commit d2ae592
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 53 deletions.
71 changes: 32 additions & 39 deletions options.go
Expand Up @@ -4,8 +4,8 @@ import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -72,17 +72,13 @@ func WithHostKeyPEM(pem []byte) ssh.Option {
// WithAuthorizedKeys allows to use a SSH authorized_keys file to allowlist users.
func WithAuthorizedKeys(path string) ssh.Option {
return func(s *ssh.Server) error {
keys, err := parseAuthorizedKeys(path)
if err != nil {
if _, err := os.Stat(path); err != nil {
return err
}
return WithPublicKeyAuth(func(_ ssh.Context, key ssh.PublicKey) bool {
for _, upk := range keys {
if ssh.KeysEqual(upk, key) {
return true
}
}
return false
return isAuthorized(path, func(k ssh.PublicKey) bool {
return ssh.KeysEqual(key, k)
})
})(s)
}
}
Expand All @@ -92,50 +88,43 @@ func WithAuthorizedKeys(path string) ssh.Option {
// Analogous to the TrustedUserCAKeys OpenSSH option.
func WithTrustedUserCAKeys(path string) ssh.Option {
return func(s *ssh.Server) error {
cas, err := parseAuthorizedKeys(path)
if err != nil {
if _, err := os.Stat(path); err != nil {
return err
}

return WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
cert, ok := key.(*gossh.Certificate)
if !ok {
// not a certificate...
return false
}

checker := &gossh.CertChecker{
IsUserAuthority: func(auth gossh.PublicKey) bool {
for _, ca := range cas {
if bytes.Equal(auth.Marshal(), ca.Marshal()) {
// its a cert signed by one of the CAs
return true
}
}
// it is a cert, but signed by another CA
return false
},
}
return isAuthorized(path, func(k ssh.PublicKey) bool {
checker := &gossh.CertChecker{
IsUserAuthority: func(auth gossh.PublicKey) bool {
// its a cert signed by one of the CAs
return bytes.Equal(auth.Marshal(), k.Marshal())
},
}

if !checker.IsUserAuthority(cert.SignatureKey) {
return false
}
if !checker.IsUserAuthority(cert.SignatureKey) {
return false
}

if err := checker.CheckCert(ctx.User(), cert); err != nil {
return false
}
if err := checker.CheckCert(ctx.User(), cert); err != nil {
return false
}

return true
return true
})
})(s)
}
}

func parseAuthorizedKeys(path string) ([]ssh.PublicKey, error) {
var keys []ssh.PublicKey

func isAuthorized(path string, checker func(k ssh.PublicKey) bool) bool {
f, err := os.Open(path)
if err != nil {
return keys, fmt.Errorf("failed to parse %q: %w", path, err)
log.Printf("failed to parse %q: %s", path, err)
return false
}
defer f.Close() // nolint: errcheck

Expand All @@ -146,7 +135,8 @@ func parseAuthorizedKeys(path string) ([]ssh.PublicKey, error) {
if errors.Is(err, io.EOF) {
break
}
return keys, fmt.Errorf("failed to parse %q: %w", path, err)
log.Printf("failed to parse %q: %s", path, err)
return false
}
if strings.TrimSpace(string(line)) == "" {
continue
Expand All @@ -156,11 +146,14 @@ func parseAuthorizedKeys(path string) ([]ssh.PublicKey, error) {
}
upk, _, _, _, err := ssh.ParseAuthorizedKey(line)
if err != nil {
return keys, fmt.Errorf("failed to parse %q: %w", path, err)
log.Printf("failed to parse %q: %s", path, err)
return false
}
if checker(upk) {
return true
}
keys = append(keys, upk)
}
return keys, nil
return false
}

// WithPublicKeyAuth returns an ssh.Option that sets the public key auth handler.
Expand Down
28 changes: 14 additions & 14 deletions options_test.go
Expand Up @@ -25,23 +25,17 @@ func TestWithMaxTimeout(t *testing.T) {
requireEqual(t, time.Second, s.MaxTimeout)
}

func TestParseAuthorizedKeys(t *testing.T) {
func TestIsAuthorized(t *testing.T) {
t.Run("valid", func(t *testing.T) {
keys, err := parseAuthorizedKeys("testdata/authorized_keys")
requireNoError(t, err)
requireEqual(t, 6, len(keys))
requireEqual(t, true, isAuthorized("testdata/authorized_keys", func(k ssh.PublicKey) bool { return true }))
})

t.Run("invalid", func(t *testing.T) {
keys, err := parseAuthorizedKeys("testdata/invalid_authorized_keys")
requireEqual(t, `failed to parse "testdata/invalid_authorized_keys": ssh: no key found`, err.Error())
requireEqual(t, 0, len(keys))
requireEqual(t, false, isAuthorized("testdata/invalid_authorized_keys", func(k ssh.PublicKey) bool { return true }))
})

t.Run("file not found", func(t *testing.T) {
keys, err := parseAuthorizedKeys("testdata/nope_authorized_keys")
requireEqual(t, `failed to parse "testdata/nope_authorized_keys": open testdata/nope_authorized_keys: no such file or directory`, err.Error())
requireEqual(t, 0, len(keys))
requireEqual(t, false, isAuthorized("testdata/nope_authorized_keys", func(k ssh.PublicKey) bool { return true }))
})
}

Expand All @@ -65,12 +59,18 @@ func TestWithAuthorizedKeys(t *testing.T) {

t.Run("invalid", func(t *testing.T) {
s := ssh.Server{}
requireEqual(
requireNoError(
t,
`failed to parse "testdata/invalid_authorized_keys": ssh: no key found`,
WithAuthorizedKeys("testdata/invalid_authorized_keys")(&s).Error(),
WithAuthorizedKeys("testdata/invalid_authorized_keys")(&s),
)
})

t.Run("file not found", func(t *testing.T) {
s := ssh.Server{}
if err := WithAuthorizedKeys("testdata/nope_authorized_keys")(&s); err == nil {
t.Fatal("expected an error, got nil")
}
})
}

func TestWithTrustedUserCAKeys(t *testing.T) {
Expand Down Expand Up @@ -102,7 +102,7 @@ func TestWithTrustedUserCAKeys(t *testing.T) {
t.Run("invalid ca key", func(t *testing.T) {
s := &ssh.Server{}
if err := WithTrustedUserCAKeys("testdata/invalid-path")(s); err == nil {
t.Fatal("expedted an error, got nil")
t.Fatal("expected an error, got nil")
}
})

Expand Down

0 comments on commit d2ae592

Please sign in to comment.