Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

release-19.2: sql: allow pgwire auth methods to specify a cleanup func #49655

Merged
merged 1 commit into from May 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 38 additions & 0 deletions pkg/acceptance/compose/gss/psql/gss_test.go
Expand Up @@ -24,6 +24,7 @@ import (
"strings"
"testing"

"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/lib/pq"
"github.com/pkg/errors"
)
Expand Down Expand Up @@ -108,6 +109,43 @@ func TestGSS(t *testing.T) {
}
}

func TestGSSFileDescriptorCount(t *testing.T) {
// When the docker-compose.yml added a ulimit for the cockroach
// container the open file count would just stop there, it wouldn't
// cause cockroach to panic or error like I had hoped since it would
// allow a test to assert that multiple gss connections didn't leak
// file descriptors. Another possibility would be to have something
// track the open file count in the cockroach container, but that seems
// brittle and probably not worth the effort. However this test is
// useful when doing manual tracking of file descriptor count.
t.Skip("skip")

rootConnector, err := pq.NewConnector("user=root sslmode=require")
if err != nil {
t.Fatal(err)
}
rootDB := gosql.OpenDB(rootConnector)
defer rootDB.Close()

if _, err := rootDB.Exec(`SET CLUSTER SETTING server.host_based_authentication.configuration = $1`, "host all all all gss include_realm=0"); err != nil {
t.Fatal(err)
}
const user = "tester"
if _, err := rootDB.Exec(fmt.Sprintf(`CREATE USER IF NOT EXISTS '%s'`, user)); err != nil {
t.Fatal(err)
}

start := timeutil.Now()
for i := 0; i < 1000; i++ {
fmt.Println(i, timeutil.Since(start))
out, err := exec.Command("psql", "-c", "SELECT 1", "-U", user).CombinedOutput()
if IsError(err, "GSS authentication requires an enterprise license") {
t.Log(string(out))
t.Fatal(err)
}
}
}

func IsError(err error, re string) bool {
if err == nil && re == "" {
return true
Expand Down
2 changes: 2 additions & 0 deletions pkg/acceptance/compose_test.go
Expand Up @@ -19,8 +19,10 @@ import (
func TestComposeGSS(t *testing.T) {
out, err := exec.Command(
"docker-compose",
"--no-ansi",
"-f", filepath.Join("compose", "gss", "docker-compose.yml"),
"up",
"--force-recreate",
"--build",
"--exit-code-from", "psql",
).CombinedOutput()
Expand Down
39 changes: 27 additions & 12 deletions pkg/ccl/gssapiccl/gssapi.go
Expand Up @@ -39,7 +39,7 @@ const (
)

// authGSS performs GSS authentication. See:
// https:github.com/postgres/postgres/blob/0f9cdd7dca694d487ab663d463b308919f591c02/src/backend/libpq/auth.c#L1090
// https://github.com/postgres/postgres/blob/0f9cdd7dca694d487ab663d463b308919f591c02/src/backend/libpq/auth.c#L1090
func authGSS(
c pgwire.AuthConn,
tlsState tls.ConnectionState,
Expand All @@ -48,7 +48,7 @@ func authGSS(
execCfg *sql.ExecutorConfig,
entry *hba.Entry,
) (security.UserAuthHook, error) {
return func(requestedUser string, clientConnection bool) error {
return func(requestedUser string, clientConnection bool) (func(), error) {
var (
majStat, minStat, lminS, gflags C.OM_uint32
gbuf C.gss_buffer_desc
Expand All @@ -62,13 +62,29 @@ func authGSS(
)

if err = c.SendAuthRequest(authTypeGSS, nil); err != nil {
return err
return nil, err
}

// This cleanup function must be called at the
// "completion of a communications session", not
// merely at the end of an authentication init. See
// https://tools.ietf.org/html/rfc2744.html, section
// `1. Introduction`, stage `d`:
//
// At the completion of a communications session (which
// may extend across several transport connections),
// each application calls a GSS-API routine to delete
// the security context.
//
// See https://github.com/postgres/postgres/blob/f4d59369d2ddf0ad7850112752ec42fd115825d4/src/backend/libpq/pqcomm.c#L269
connClose := func() {
C.gss_delete_sec_context(&lminS, &contextHandle, C.GSS_C_NO_BUFFER)
}

for {
token, err = c.GetPwdData()
if err != nil {
return err
return connClose, err
}

gbuf.length = C.ulong(len(token))
Expand All @@ -93,12 +109,11 @@ func authGSS(
outputBytes := C.GoBytes(outputToken.value, C.int(outputToken.length))
C.gss_release_buffer(&lminS, &outputToken)
if err = c.SendAuthRequest(authTypeGSSContinue, outputBytes); err != nil {
return err
return connClose, err
}
}
if majStat != C.GSS_S_COMPLETE && majStat != C.GSS_S_CONTINUE_NEEDED {
C.gss_delete_sec_context(&lminS, &contextHandle, C.GSS_C_NO_BUFFER)
return gssError("accepting GSS security context failed", majStat, minStat)
return connClose, gssError("accepting GSS security context failed", majStat, minStat)
}
if majStat != C.GSS_S_CONTINUE_NEEDED {
break
Expand All @@ -107,7 +122,7 @@ func authGSS(

majStat = C.gss_display_name(&minStat, srcName, &gbuf, nil)
if majStat != C.GSS_S_COMPLETE {
return gssError("retrieving GSS user name failed", majStat, minStat)
return connClose, gssError("retrieving GSS user name failed", majStat, minStat)
}
gssUser := C.GoStringN((*C.char)(gbuf.value), C.int(gbuf.length))
C.gss_release_buffer(&lminS, &gbuf)
Expand All @@ -125,25 +140,25 @@ func authGSS(
}
}
if !matched {
return errors.Errorf("GSSAPI realm (%s) didn't match any configured realm", realm)
return connClose, errors.Errorf("GSSAPI realm (%s) didn't match any configured realm", realm)
}
}
if entry.GetOption("include_realm") != "1" {
gssUser = gssUser[:idx]
}
} else if len(realms) > 0 {
return errors.New("GSSAPI did not return realm but realm matching was requested")
return connClose, errors.New("GSSAPI did not return realm but realm matching was requested")
}

if !strings.EqualFold(gssUser, requestedUser) {
return errors.Errorf("requested user is %s, but GSSAPI auth is for %s", requestedUser, gssUser)
return connClose, errors.Errorf("requested user is %s, but GSSAPI auth is for %s", requestedUser, gssUser)
}

// Do the license check last so that administrators are able to test whether
// their GSS configuration is correct. That is, the presence of this error
// message means they have a correctly functioning GSS/Kerberos setup,
// but now need to enable enterprise features.
return utilccl.CheckEnterpriseEnabled(execCfg.Settings, execCfg.ClusterID(), execCfg.Organization(), "GSS authentication")
return connClose, utilccl.CheckEnterpriseEnabled(execCfg.Settings, execCfg.ClusterID(), execCfg.Organization(), "GSS authentication")
}, nil
}

Expand Down
31 changes: 16 additions & 15 deletions pkg/security/auth.go
Expand Up @@ -24,8 +24,9 @@ const (
)

// UserAuthHook authenticates a user based on their username and whether their
// connection originates from a client or another node in the cluster.
type UserAuthHook func(string, bool) error
// connection originates from a client or another node in the cluster. It
// returns an optional func that is run at connection close.
type UserAuthHook func(string, bool) (connClose func(), _ error)

// GetCertificateUser extract the username from a client certificate.
func GetCertificateUser(tlsState *tls.ConnectionState) (string, error) {
Expand Down Expand Up @@ -54,58 +55,58 @@ func UserAuthCertHook(insecureMode bool, tlsState *tls.ConnectionState) (UserAut
}
}

return func(requestedUser string, clientConnection bool) error {
return func(requestedUser string, clientConnection bool) (func(), error) {
// TODO(marc): we may eventually need stricter user syntax rules.
if len(requestedUser) == 0 {
return errors.New("user is missing")
return nil, errors.New("user is missing")
}

if !clientConnection && requestedUser != NodeUser {
return errors.Errorf("user %s is not allowed", requestedUser)
return nil, errors.Errorf("user %s is not allowed", requestedUser)
}

// If running in insecure mode, we have nothing to verify it against.
if insecureMode {
return nil
return nil, nil
}

// The client certificate user must match the requested user,
// except if the certificate user is NodeUser, which is allowed to
// act on behalf of all other users.
if !(certUser == NodeUser || certUser == requestedUser) {
return errors.Errorf("requested user is %s, but certificate is for %s", requestedUser, certUser)
return nil, errors.Errorf("requested user is %s, but certificate is for %s", requestedUser, certUser)
}

return nil
return nil, nil
}, nil
}

// UserAuthPasswordHook builds an authentication hook based on the security
// mode, password, and its potentially matching hash.
func UserAuthPasswordHook(insecureMode bool, password string, hashedPassword []byte) UserAuthHook {
return func(requestedUser string, clientConnection bool) error {
return func(requestedUser string, clientConnection bool) (func(), error) {
if len(requestedUser) == 0 {
return errors.New("user is missing")
return nil, errors.New("user is missing")
}

if !clientConnection {
return errors.New("password authentication is only available for client connections")
return nil, errors.New("password authentication is only available for client connections")
}

if insecureMode {
return nil
return nil, nil
}

if requestedUser == RootUser {
return errors.Errorf("user %s must use certificate authentication instead of password authentication", RootUser)
return nil, errors.Errorf("user %s must use certificate authentication instead of password authentication", RootUser)
}

// If the requested user has an empty password, disallow authentication.
if len(password) == 0 || CompareHashAndPassword(hashedPassword, password) != nil {
return errors.Errorf(ErrPasswordUserAuthFailed, requestedUser)
return nil, errors.Errorf(ErrPasswordUserAuthFailed, requestedUser)
}

return nil
return nil, nil
}
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/security/auth_test.go
Expand Up @@ -109,11 +109,11 @@ func TestAuthenticationHook(t *testing.T) {
if err != nil {
continue
}
err = hook(tc.username, true /*public*/)
_, err = hook(tc.username, true /*public*/)
if (err == nil) != tc.publicHookSuccess {
t.Fatalf("#%d: expected success=%t, got err=%v", tcNum, tc.publicHookSuccess, err)
}
err = hook(tc.username, false /*not public*/)
_, err = hook(tc.username, false /*not public*/)
if (err == nil) != tc.privateHookSuccess {
t.Fatalf("#%d: expected success=%t, got err=%v", tcNum, tc.privateHookSuccess, err)
}
Expand Down
30 changes: 17 additions & 13 deletions pkg/sql/pgwire/conn.go
Expand Up @@ -514,6 +514,7 @@ func (c *conn) processCommandsAsync(
var retErr error
var connHandler sql.ConnectionHandler
var authOK bool
var connCloseAuthHandler func()
defer func() {
// Release resources, if we still own them.
if reservedOwned {
Expand Down Expand Up @@ -552,6 +553,9 @@ func (c *conn) processCommandsAsync(
if !authOK {
ac.AuthFail(retErr)
}
if connCloseAuthHandler != nil {
connCloseAuthHandler()
}
// Inform the connection goroutine of success or failure.
retCh <- retErr
}()
Expand All @@ -563,7 +567,7 @@ func (c *conn) processCommandsAsync(
return
}
} else {
if retErr = c.handleAuthentication(
if connCloseAuthHandler, retErr = c.handleAuthentication(
ctx, ac, authOpt.insecure, authOpt.ie, authOpt.auth,
sqlServer.GetExecutorConfig(),
); retErr != nil {
Expand Down Expand Up @@ -1473,7 +1477,7 @@ func (c *conn) handleAuthentication(
ie *sql.InternalExecutor,
auth *hba.Conf,
execCfg *sql.ExecutorConfig,
) error {
) (connClose func(), _ error) {
sendError := func(err error) error {
_ /* err */ = writeErr(ctx, &execCfg.Settings.SV, err, &c.msgBuilder, c.conn)
return err
Expand All @@ -1485,10 +1489,10 @@ func (c *conn) handleAuthentication(
ctx, ie, &c.metrics.SQLMemMetrics, c.sessionArgs.User,
)
if err != nil {
return sendError(err)
return nil, sendError(err)
}
if !exists {
return sendError(errors.Errorf(security.ErrPasswordUserAuthFailed, c.sessionArgs.User))
return nil, sendError(errors.Errorf(security.ErrPasswordUserAuthFailed, c.sessionArgs.User))
}

if tlsConn, ok := c.conn.(*readTimeoutConn).Conn.(*tls.Conn); ok {
Expand All @@ -1507,7 +1511,7 @@ func (c *conn) handleAuthentication(
} else {
addr, _, err := net.SplitHostPort(c.conn.RemoteAddr().String())
if err != nil {
return sendError(err)
return nil, sendError(err)
}
ip := net.ParseIP(addr)
for _, entry := range auth.Entries {
Expand All @@ -1518,10 +1522,10 @@ func (c *conn) handleAuthentication(
}
case hba.String:
if !a.IsSpecial("all") {
return sendError(errors.Errorf("unexpected %s address: %q", serverHBAConfSetting, a.Value))
return nil, sendError(errors.Errorf("unexpected %s address: %q", serverHBAConfSetting, a.Value))
}
default:
return sendError(errors.Errorf("unexpected address type %T", a))
return nil, sendError(errors.Errorf("unexpected address type %T", a))
}
match := false
for _, u := range entry.User {
Expand All @@ -1539,28 +1543,28 @@ func (c *conn) handleAuthentication(
}
methodFn = hbaAuthMethods[entry.Method]
if methodFn == nil {
return sendError(errors.Errorf("unknown auth method %s", entry.Method))
return nil, sendError(errors.Errorf("unknown auth method %s", entry.Method))
}
hbaEntry = &entry
break
}
if methodFn == nil {
return sendError(errors.Errorf("no %s entry for host %q, user %q", serverHBAConfSetting, addr, c.sessionArgs.User))
return nil, sendError(errors.Errorf("no %s entry for host %q, user %q", serverHBAConfSetting, addr, c.sessionArgs.User))
}
}

authenticationHook, err := methodFn(ac, tlsState, insecure, hashedPassword, execCfg, hbaEntry)
if err != nil {
return sendError(err)
return nil, sendError(err)
}
if err := authenticationHook(c.sessionArgs.User, true /* public */); err != nil {
return sendError(err)
if connClose, err = authenticationHook(c.sessionArgs.User, true /* public */); err != nil {
return connClose, sendError(err)
}
}

c.msgBuilder.initMsg(pgwirebase.ServerMsgAuth)
c.msgBuilder.putInt32(authOK)
return c.msgBuilder.finishMsg(c.conn)
return connClose, c.msgBuilder.finishMsg(c.conn)
}

const serverHBAConfSetting = "server.host_based_authentication.configuration"
Expand Down