From c01c4a46b1bad277c57528b7288e4bb313b87360 Mon Sep 17 00:00:00 2001 From: Agniva De Sarker Date: Wed, 5 Apr 2023 22:04:29 +0530 Subject: [PATCH] CP: MM-51768: Scrub username/password from SQL datasource (#22853) Automatic Merge --- config/database.go | 5 +++- config/database_test.go | 6 ++-- config/utils.go | 21 -------------- config/utils_test.go | 55 ------------------------------------ store/sqlstore/store.go | 4 ++- store/sqlstore/utils.go | 27 ++++++++++++++++++ store/sqlstore/utils_test.go | 43 ++++++++++++++++++++++++++++ 7 files changed, 80 insertions(+), 81 deletions(-) diff --git a/config/database.go b/config/database.go index e9f9328d38dd0..54d96d41aeb0b 100644 --- a/config/database.go +++ b/config/database.go @@ -406,7 +406,10 @@ func (ds *DatabaseStore) RemoveFile(name string) error { // String returns the path to the database backing the config, masking the password. func (ds *DatabaseStore) String() string { - return stripPassword(ds.originalDsn, ds.driverName) + // This is called during the running of MM, so we expect the parsing of DSN + // to be successful. + sanitized, _ := sqlstore.SanitizeDataSource(ds.driverName, ds.originalDsn) + return sanitized } // Close cleans up resources associated with the store. diff --git a/config/database_test.go b/config/database_test.go index 4eab71fc5c833..6954461a08202 100644 --- a/config/database_test.go +++ b/config/database_test.go @@ -1107,12 +1107,12 @@ func TestDatabaseStoreString(t *testing.T) { if *mainHelper.GetSQLSettings().DriverName == "postgres" { maskedDSN := ds.String() assert.True(t, strings.HasPrefix(maskedDSN, "postgres://")) - assert.True(t, strings.Contains(maskedDSN, "mmuser")) + assert.False(t, strings.Contains(maskedDSN, "mmuser")) assert.False(t, strings.Contains(maskedDSN, "mostest")) } else { maskedDSN := ds.String() - assert.True(t, strings.HasPrefix(maskedDSN, "mysql://")) - assert.True(t, strings.Contains(maskedDSN, "mmuser")) + assert.False(t, strings.HasPrefix(maskedDSN, "mysql://")) + assert.False(t, strings.Contains(maskedDSN, "mmuser")) assert.False(t, strings.Contains(maskedDSN, "mostest")) } } diff --git a/config/utils.go b/config/utils.go index 9faa69e738943..42371641be14c 100644 --- a/config/utils.go +++ b/config/utils.go @@ -179,27 +179,6 @@ func IsDatabaseDSN(dsn string) bool { strings.HasPrefix(dsn, "postgresql://") } -// stripPassword remove the password from a given DSN -func stripPassword(dsn, schema string) string { - prefix := schema + "://" - dsn = strings.TrimPrefix(dsn, prefix) - - i := strings.Index(dsn, ":") - j := strings.LastIndex(dsn, "@") - - // Return error if no @ sign is found - if j < 0 { - return "(omitted due to error parsing the DSN)" - } - - // Return back the input if no password is found - if i < 0 || i > j { - return prefix + dsn - } - - return prefix + dsn[:i+1] + dsn[j:] -} - func isJSONMap(data string) bool { var m map[string]any return json.Unmarshal([]byte(data), &m) == nil diff --git a/config/utils_test.go b/config/utils_test.go index 4788fac807bbb..7d9b9c11b1e96 100644 --- a/config/utils_test.go +++ b/config/utils_test.go @@ -197,61 +197,6 @@ func TestIsDatabaseDSN(t *testing.T) { } } -func TestStripPassword(t *testing.T) { - for name, test := range map[string]struct { - DSN string - Schema string - ExpectedOut string - }{ - "mysql": { - DSN: "mysql://mmuser:password@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s", - Schema: "mysql", - ExpectedOut: "mysql://mmuser:@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s", - }, - "mysql idempotent": { - DSN: "mysql://mmuser:@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s", - Schema: "mysql", - ExpectedOut: "mysql://mmuser:@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s", - }, - "mysql: password with : and @": { - DSN: "mysql://mmuser:p:assw@ord@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s", - Schema: "mysql", - ExpectedOut: "mysql://mmuser:@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s", - }, - "mysql: password with @ and :": { - DSN: "mysql://mmuser:pa@sswo:rd@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s", - Schema: "mysql", - ExpectedOut: "mysql://mmuser:@tcp(localhost:3306)/mattermost?charset=utf8mb4,utf8&readTimeout=30s", - }, - "postgres": { - DSN: "postgres://mmuser:password@localhost:5432/mattermost?sslmode=disable&connect_timeout=10", - Schema: "postgres", - ExpectedOut: "postgres://mmuser:@localhost:5432/mattermost?sslmode=disable&connect_timeout=10", - }, - "pipe": { - DSN: "mysql://user@unix(/path/to/socket)/dbname", - Schema: "mysql", - ExpectedOut: "mysql://user@unix(/path/to/socket)/dbname", - }, - "malformed without :": { - DSN: "postgres://mmuserpassword@localhost:5432/mattermost?sslmode=disable&connect_timeout=10", - Schema: "postgres", - ExpectedOut: "postgres://mmuserpassword@localhost:5432/mattermost?sslmode=disable&connect_timeout=10", - }, - "malformed without @": { - DSN: "postgres://mmuser:passwordlocalhost:5432/mattermost?sslmode=disable&connect_timeout=10", - Schema: "postgres", - ExpectedOut: "(omitted due to error parsing the DSN)", - }, - } { - t.Run(name, func(t *testing.T) { - out := stripPassword(test.DSN, test.Schema) - - assert.Equal(t, test.ExpectedOut, out) - }) - } -} - func TestIsJSONMap(t *testing.T) { tests := []struct { name string diff --git a/store/sqlstore/store.go b/store/sqlstore/store.go index 28a86886cfdb7..52b343591c61b 100644 --- a/store/sqlstore/store.go +++ b/store/sqlstore/store.go @@ -235,7 +235,9 @@ func SetupConnection(connType string, dataSource string, settings *model.SqlSett } for i := 0; i < DBPingAttempts; i++ { - mlog.Info("Pinging SQL", mlog.String("database", connType)) + // At this point, we have passed sql.Open, so we deliberately ignore any errors. + sanitized, _ := SanitizeDataSource(*settings.DriverName, dataSource) + mlog.Info("Pinging SQL", mlog.String("database", connType), mlog.String("dataSource", sanitized)) ctx, cancel := context.WithTimeout(context.Background(), DBPingTimeoutSecs*time.Second) defer cancel() err = db.PingContext(ctx) diff --git a/store/sqlstore/utils.go b/store/sqlstore/utils.go index e7fe4958d9d47..bdf43abbca421 100644 --- a/store/sqlstore/utils.go +++ b/store/sqlstore/utils.go @@ -5,6 +5,7 @@ package sqlstore import ( "database/sql" + "errors" "io" "net/url" "strconv" @@ -205,3 +206,29 @@ func ResetReadTimeout(dataSource string) (string, error) { config.ReadTimeout = 0 return config.FormatDSN(), nil } + +func SanitizeDataSource(driverName, dataSource string) (string, error) { + switch driverName { + case model.DatabaseDriverPostgres: + u, err := url.Parse(dataSource) + if err != nil { + return "", err + } + u.User = url.UserPassword("****", "****") + params := u.Query() + params.Del("user") + params.Del("password") + u.RawQuery = params.Encode() + return u.String(), nil + case model.DatabaseDriverMysql: + cfg, err := mysql.ParseDSN(dataSource) + if err != nil { + return "", err + } + cfg.User = "****" + cfg.Passwd = "****" + return cfg.FormatDSN(), nil + default: + return "", errors.New("invalid drivername. Not postgres or mysql.") + } +} diff --git a/store/sqlstore/utils_test.go b/store/sqlstore/utils_test.go index 811ebf001a301..96468323bbd12 100644 --- a/store/sqlstore/utils_test.go +++ b/store/sqlstore/utils_test.go @@ -6,6 +6,7 @@ package sqlstore import ( "testing" + "github.com/mattermost/mattermost-server/v6/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -160,3 +161,45 @@ func TestAppendMultipleStatementsFlag(t *testing.T) { }) } } + +func TestSanitizeDataSource(t *testing.T) { + t.Run(model.DatabaseDriverPostgres, func(t *testing.T) { + testCases := []struct { + Original string + Sanitized string + }{ + { + "postgres://mmuser:mostest@localhost/dummy?sslmode=disable", + "postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable", + }, + { + "postgres://localhost/dummy?sslmode=disable&user=mmuser&password=mostest", + "postgres://%2A%2A%2A%2A:%2A%2A%2A%2A@localhost/dummy?sslmode=disable", + }, + } + driver := model.DatabaseDriverPostgres + for _, tc := range testCases { + out, err := SanitizeDataSource(driver, tc.Original) + require.NoError(t, err) + assert.Equal(t, tc.Sanitized, out) + } + }) + + t.Run(model.DatabaseDriverMysql, func(t *testing.T) { + testCases := []struct { + Original string + Sanitized string + }{ + { + "mmuser:mostest@tcp(localhost:3306)/mattermost_test?charset=utf8mb4,utf8&readTimeout=30s&writeTimeout=30s", + "****:****@tcp(localhost:3306)/mattermost_test?readTimeout=30s&writeTimeout=30s&charset=utf8mb4%2Cutf8", + }, + } + driver := model.DatabaseDriverMysql + for _, tc := range testCases { + out, err := SanitizeDataSource(driver, tc.Original) + require.NoError(t, err) + assert.Equal(t, tc.Sanitized, out) + } + }) +}