Skip to content

Commit

Permalink
CP: MM-51768: Scrub username/password from SQL datasource (#22853)
Browse files Browse the repository at this point in the history
Automatic Merge
  • Loading branch information
agnivade committed Apr 5, 2023
1 parent dcab64f commit c01c4a4
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 81 deletions.
5 changes: 4 additions & 1 deletion config/database.go
Expand Up @@ -406,7 +406,10 @@ func (ds *DatabaseStore) RemoveFile(name string) error {

// String returns the path to the database backing the config, masking the password.
func (ds *DatabaseStore) String() string {
return stripPassword(ds.originalDsn, ds.driverName)
// This is called during the running of MM, so we expect the parsing of DSN
// to be successful.
sanitized, _ := sqlstore.SanitizeDataSource(ds.driverName, ds.originalDsn)
return sanitized
}

// Close cleans up resources associated with the store.
Expand Down
6 changes: 3 additions & 3 deletions config/database_test.go
Expand Up @@ -1107,12 +1107,12 @@ func TestDatabaseStoreString(t *testing.T) {
if *mainHelper.GetSQLSettings().DriverName == "postgres" {
maskedDSN := ds.String()
assert.True(t, strings.HasPrefix(maskedDSN, "postgres://"))
assert.True(t, strings.Contains(maskedDSN, "mmuser"))
assert.False(t, strings.Contains(maskedDSN, "mmuser"))
assert.False(t, strings.Contains(maskedDSN, "mostest"))
} else {
maskedDSN := ds.String()
assert.True(t, strings.HasPrefix(maskedDSN, "mysql://"))
assert.True(t, strings.Contains(maskedDSN, "mmuser"))
assert.False(t, strings.HasPrefix(maskedDSN, "mysql://"))
assert.False(t, strings.Contains(maskedDSN, "mmuser"))
assert.False(t, strings.Contains(maskedDSN, "mostest"))
}
}
Expand Down
21 changes: 0 additions & 21 deletions config/utils.go
Expand Up @@ -179,27 +179,6 @@ func IsDatabaseDSN(dsn string) bool {
strings.HasPrefix(dsn, "postgresql://")
}

// stripPassword remove the password from a given DSN
func stripPassword(dsn, schema string) string {
prefix := schema + "://"
dsn = strings.TrimPrefix(dsn, prefix)

i := strings.Index(dsn, ":")
j := strings.LastIndex(dsn, "@")

// Return error if no @ sign is found
if j < 0 {
return "(omitted due to error parsing the DSN)"
}

// Return back the input if no password is found
if i < 0 || i > j {
return prefix + dsn
}

return prefix + dsn[:i+1] + dsn[j:]
}

func isJSONMap(data string) bool {
var m map[string]any
return json.Unmarshal([]byte(data), &m) == nil
Expand Down
55 changes: 0 additions & 55 deletions config/utils_test.go
Expand Up @@ -197,61 +197,6 @@ func TestIsDatabaseDSN(t *testing.T) {
}
}

func TestStripPassword(t *testing.T) {
for name, test := range map[string]struct {
DSN string
Schema string
ExpectedOut string
}{
"mysql": {
DSN: "mysql://mmuser:password@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s",
Schema: "mysql",
ExpectedOut: "mysql://mmuser:@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s",
},
"mysql idempotent": {
DSN: "mysql://mmuser:@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s",
Schema: "mysql",
ExpectedOut: "mysql://mmuser:@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s",
},
"mysql: password with : and @": {
DSN: "mysql://mmuser:p:assw@ord@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s",
Schema: "mysql",
ExpectedOut: "mysql://mmuser:@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s",
},
"mysql: password with @ and :": {
DSN: "mysql://mmuser:pa@sswo:rd@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s",
Schema: "mysql",
ExpectedOut: "mysql://mmuser:@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s",
},
"postgres": {
DSN: "postgres://mmuser:password@localhost:5432/mattermost?sslmode=disable&connect_timeout=10",
Schema: "postgres",
ExpectedOut: "postgres://mmuser:@localhost:5432/mattermost?sslmode=disable&connect_timeout=10",
},
"pipe": {
DSN: "mysql://user@unix(/path/to/socket)/dbname",
Schema: "mysql",
ExpectedOut: "mysql://user@unix(/path/to/socket)/dbname",
},
"malformed without :": {
DSN: "postgres://mmuserpassword@localhost:5432/mattermost?sslmode=disable&connect_timeout=10",
Schema: "postgres",
ExpectedOut: "postgres://mmuserpassword@localhost:5432/mattermost?sslmode=disable&connect_timeout=10",
},
"malformed without @": {
DSN: "postgres://mmuser:passwordlocalhost:5432/mattermost?sslmode=disable&connect_timeout=10",
Schema: "postgres",
ExpectedOut: "(omitted due to error parsing the DSN)",
},
} {
t.Run(name, func(t *testing.T) {
out := stripPassword(test.DSN, test.Schema)

assert.Equal(t, test.ExpectedOut, out)
})
}
}

func TestIsJSONMap(t *testing.T) {
tests := []struct {
name string
Expand Down
4 changes: 3 additions & 1 deletion store/sqlstore/store.go
Expand Up @@ -235,7 +235,9 @@ func SetupConnection(connType string, dataSource string, settings *model.SqlSett
}

for i := 0; i < DBPingAttempts; i++ {
mlog.Info("Pinging SQL", mlog.String("database", connType))
// At this point, we have passed sql.Open, so we deliberately ignore any errors.
sanitized, _ := SanitizeDataSource(*settings.DriverName, dataSource)
mlog.Info("Pinging SQL", mlog.String("database", connType), mlog.String("dataSource", sanitized))
ctx, cancel := context.WithTimeout(context.Background(), DBPingTimeoutSecs*time.Second)
defer cancel()
err = db.PingContext(ctx)
Expand Down
27 changes: 27 additions & 0 deletions store/sqlstore/utils.go
Expand Up @@ -5,6 +5,7 @@ package sqlstore

import (
"database/sql"
"errors"
"io"
"net/url"
"strconv"
Expand Down Expand Up @@ -205,3 +206,29 @@ func ResetReadTimeout(dataSource string) (string, error) {
config.ReadTimeout = 0
return config.FormatDSN(), nil
}

func SanitizeDataSource(driverName, dataSource string) (string, error) {
switch driverName {
case model.DatabaseDriverPostgres:
u, err := url.Parse(dataSource)
if err != nil {
return "", err
}
u.User = url.UserPassword("****", "****")
params := u.Query()
params.Del("user")
params.Del("password")
u.RawQuery = params.Encode()
return u.String(), nil
case model.DatabaseDriverMysql:
cfg, err := mysql.ParseDSN(dataSource)
if err != nil {
return "", err
}
cfg.User = "****"
cfg.Passwd = "****"
return cfg.FormatDSN(), nil
default:
return "", errors.New("invalid drivername. Not postgres or mysql.")
}
}
43 changes: 43 additions & 0 deletions store/sqlstore/utils_test.go
Expand Up @@ -6,6 +6,7 @@ package sqlstore
import (
"testing"

"github.com/mattermost/mattermost-server/v6/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -160,3 +161,45 @@ func TestAppendMultipleStatementsFlag(t *testing.T) {
})
}
}

func TestSanitizeDataSource(t *testing.T) {
t.Run(model.DatabaseDriverPostgres, func(t *testing.T) {
testCases := []struct {
Original string
Sanitized string
}{
{
"postgres://mmuser:mostest@localhost/dummy?sslmode=disable",
"postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable",
},
{
"postgres://localhost/dummy?sslmode=disable&user=mmuser&password=mostest",
"postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable",
},
}
driver := model.DatabaseDriverPostgres
for _, tc := range testCases {
out, err := SanitizeDataSource(driver, tc.Original)
require.NoError(t, err)
assert.Equal(t, tc.Sanitized, out)
}
})

t.Run(model.DatabaseDriverMysql, func(t *testing.T) {
testCases := []struct {
Original string
Sanitized string
}{
{
"mmuser:mostest@tcp(localhost:3306)/mattermost_test?charset=utf8mb4,utf8&readTimeout=30s&writeTimeout=30s",
"****:****@tcp(localhost:3306)/mattermost_test?readTimeout=30s&writeTimeout=30s&charset=utf8mb4%2Cutf8",
},
}
driver := model.DatabaseDriverMysql
for _, tc := range testCases {
out, err := SanitizeDataSource(driver, tc.Original)
require.NoError(t, err)
assert.Equal(t, tc.Sanitized, out)
}
})
}

0 comments on commit c01c4a4

Please sign in to comment.