diff --git a/statement_parser.go b/statement_parser.go index 74e52684..20162545 100644 --- a/statement_parser.go +++ b/statement_parser.go @@ -683,7 +683,9 @@ func (p *statementParser) findParams(sql string) (string, []string, error) { return "", nil, err } info := p.detectStatementType(sql) - p.statementsCache.Add(sql, &statementsCacheEntry{sql: namedParamsSql, params: params, info: info}) + cachedParams := make([]string, len(params)) + copy(cachedParams, params) + p.statementsCache.Add(sql, &statementsCacheEntry{sql: namedParamsSql, params: cachedParams, info: info}) return namedParamsSql, params, nil } } diff --git a/statement_parser_test.go b/statement_parser_test.go index bd7bd6fa..f0398238 100644 --- a/statement_parser_test.go +++ b/statement_parser_test.go @@ -2378,6 +2378,27 @@ func TestDetectStatementType(t *testing.T) { } } +func TestCachedParamsAreImmutable(t *testing.T) { + parser, err := newStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000) + if err != nil { + t.Fatal(err) + } + for n := 0; n < 2; n++ { + _, params, err := parser.findParams("select * from test where id=?") + if err != nil { + t.Fatal(err) + } + if g, w := len(params), 1; g != w { + t.Fatalf("params length mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := params[0], "p1"; g != w { + t.Fatalf("param mismatch\n Got: %v\nWant: %v", g, w) + } + // Modify the params we got from the parser and verify that this does not modify the cached params. + params[0] = "test" + } +} + func BenchmarkDetectStatementTypeWithoutCache(b *testing.B) { parser, err := newStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 0) if err != nil {