diff --git a/go/database/sql/connection.go b/go/database/sql/connection.go new file mode 100644 index 00000000..2dbb5f1d --- /dev/null +++ b/go/database/sql/connection.go @@ -0,0 +1,151 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "context" + "database/sql/driver" + "fmt" + "runtime/debug" + "strings" + + "github.com/google/sqlcommenter/go/core" +) + +var attemptedToAutosetApplication = false + +type sqlCommenterConn struct { + driver.Conn + options core.CommenterOptions +} + +func newSQLCommenterConn(conn driver.Conn, options core.CommenterOptions) *sqlCommenterConn { + return &sqlCommenterConn{ + Conn: conn, + options: options, + } +} + +func (s *sqlCommenterConn) Query(query string, args []driver.Value) (driver.Rows, error) { + queryer, ok := s.Conn.(driver.Queryer) + if !ok { + return nil, driver.ErrSkip + } + ctx := context.Background() + commentedQuery := s.withComment(ctx, query) + return queryer.Query(commentedQuery, args) +} + +func (s *sqlCommenterConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + queryer, ok := s.Conn.(driver.QueryerContext) + if !ok { + return nil, driver.ErrSkip + } + commentedQuery := s.withComment(ctx, query) + return queryer.QueryContext(ctx, commentedQuery, args) +} + +func (s *sqlCommenterConn) Exec(query string, args []driver.Value) (driver.Result, error) { + execor, ok := s.Conn.(driver.Execer) + if !ok { + return nil, driver.ErrSkip + } + ctx := context.Background() + commentedQuery := s.withComment(ctx, query) + return execor.Exec(commentedQuery, args) +} + +func (s *sqlCommenterConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + execor, ok := s.Conn.(driver.ExecerContext) + if !ok { + return nil, driver.ErrSkip + } + commentedQuery := s.withComment(ctx, query) + return execor.ExecContext(ctx, commentedQuery, args) +} + +func (s *sqlCommenterConn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) { + preparer, ok := s.Conn.(driver.ConnPrepareContext) + if !ok { + return nil, driver.ErrSkip + } + commentedQuery := s.withComment(ctx, query) + return preparer.PrepareContext(ctx, commentedQuery) +} + +func (s *sqlCommenterConn) Raw() driver.Conn { + return s.Conn +} + +// ***** Commenter Functions ***** + +func (conn *sqlCommenterConn) withComment(ctx context.Context, query string) string { + var commentsMap = map[string]string{} + query = strings.TrimSpace(query) + config := conn.options.Config + + // Sorted alphabetically + if config.EnableAction && (ctx.Value(core.Action) != nil) { + commentsMap[core.Action] = ctx.Value(core.Action).(string) + } + + // `driver` information should not be coming from framework. + // So, explicitly adding that here. + if config.EnableDBDriver { + commentsMap[core.Driver] = fmt.Sprintf("database/sql:%s", conn.options.Tags.DriverName) + } + + if config.EnableFramework && (ctx.Value(core.Framework) != nil) { + commentsMap[core.Framework] = ctx.Value(core.Framework).(string) + } + + if config.EnableRoute && (ctx.Value(core.Route) != nil) { + commentsMap[core.Route] = ctx.Value(core.Route).(string) + } + + if config.EnableTraceparent { + carrier := core.ExtractTraceparent(ctx) + if val, ok := carrier["traceparent"]; ok { + commentsMap[core.Traceparent] = val + } + } + + if config.EnableApplication { + if !attemptedToAutosetApplication && conn.options.Tags.Application == "" { + attemptedToAutosetApplication = true + bi, ok := debug.ReadBuildInfo() + if ok { + conn.options.Tags.Application = bi.Path + } + } + if conn.options.Tags.Application != "" { + commentsMap[core.Application] = conn.options.Tags.Application + } + } + + var commentsString string = "" + if len(commentsMap) > 0 { // Converts comments map to string and appends it to query + commentsString = fmt.Sprintf("/*%s*/", core.ConvertMapToComment(commentsMap)) + } + + // A semicolon at the end of the SQL statement means the query ends there. + // We need to insert the comment before that to be considered as part of the SQL statemtent. + if query[len(query)-1:] == ";" { + return fmt.Sprintf("%s%s;", strings.TrimSuffix(query, ";"), commentsString) + } + return fmt.Sprintf("%s%s", query, commentsString) +} + +// ***** Commenter Functions ***** diff --git a/go/database/sql/connection_test.go b/go/database/sql/connection_test.go new file mode 100644 index 00000000..f5ed0f82 --- /dev/null +++ b/go/database/sql/connection_test.go @@ -0,0 +1,167 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "context" + "testing" + + "github.com/google/sqlcommenter/go/core" +) + +func TestWithComment_NoContext(t *testing.T) { + testBasicConn := &mockConn{} + testCases := []struct { + desc string + commenterOptions core.CommenterOptions + query string + wantQuery string + }{ + { + desc: "empty commenter options", + commenterOptions: core.CommenterOptions{}, + query: "SELECT 1;", + wantQuery: "SELECT 1;", + }, + { + desc: "only enable DBDriver", + commenterOptions: core.CommenterOptions{ + Config: core.CommenterConfig{EnableDBDriver: true}, + }, + query: "SELECT 1;", + wantQuery: "SELECT 1/*db_driver=database%2Fsql%3A*/;", + }, + { + desc: "enable DBDriver and pass static tag driver name", + commenterOptions: core.CommenterOptions{ + Config: core.CommenterConfig{EnableDBDriver: true}, + Tags: core.StaticTags{DriverName: "postgres"}, + }, + query: "SELECT 1;", + wantQuery: "SELECT 1/*db_driver=database%2Fsql%3Apostgres*/;", + }, + { + desc: "enable DBDriver and pass all static tags", + commenterOptions: core.CommenterOptions{ + Config: core.CommenterConfig{EnableDBDriver: true}, + Tags: core.StaticTags{DriverName: "postgres", Application: "app-1"}, + }, + query: "SELECT 1;", + wantQuery: "SELECT 1/*db_driver=database%2Fsql%3Apostgres*/;", + }, + { + desc: "enable other tags and pass all static tags", + commenterOptions: core.CommenterOptions{ + Config: core.CommenterConfig{EnableDBDriver: true, EnableApplication: true, EnableFramework: true}, + Tags: core.StaticTags{DriverName: "postgres", Application: "app-1"}, + }, + query: "SELECT 1;", + wantQuery: "SELECT 1/*application=app-1,db_driver=database%2Fsql%3Apostgres*/;", + }, + } + for _, tc := range testCases { + testConn := newSQLCommenterConn(testBasicConn, tc.commenterOptions) + ctx := context.Background() + if got, want := testConn.withComment(ctx, tc.query), tc.wantQuery; got != want { + t.Errorf("testConn.withComment(ctx, %q) = %q, want = %q", tc.query, got, want) + } + } +} + +func TestWithComment_WithContext(t *testing.T) { + testBasicConn := &mockConn{} + testCases := []struct { + desc string + commenterOptions core.CommenterOptions + ctx context.Context + query string + wantQuery string + }{ + { + desc: "empty commenter options", + commenterOptions: core.CommenterOptions{}, + ctx: getContextWithKeyValue( + map[string]string{ + "route": "listData", + "framework": "custom-golang", + }, + ), + query: "SELECT 1;", + wantQuery: "SELECT 1;", + }, + { + desc: "only all options but context has few tags", + commenterOptions: core.CommenterOptions{ + Config: core.CommenterConfig{ + EnableDBDriver: true, + EnableRoute: true, + EnableFramework: true, + EnableController: true, + EnableAction: true, + EnableTraceparent: true, + EnableApplication: true, + }, + Tags: core.StaticTags{DriverName: "postgres", Application: "app-1"}, + }, + ctx: getContextWithKeyValue( + map[string]string{ + "route": "listData", + "framework": "custom-golang", + }, + ), + query: "SELECT 1;", + wantQuery: "SELECT 1/*application=app-1,db_driver=database%2Fsql%3Apostgres,framework=custom-golang,route=listData*/;", + }, + { + desc: "only all options but context contains all tags", + commenterOptions: core.CommenterOptions{ + Config: core.CommenterConfig{ + EnableDBDriver: true, + EnableRoute: true, + EnableFramework: true, + EnableController: true, + EnableAction: true, + EnableTraceparent: true, + EnableApplication: true, + }, + Tags: core.StaticTags{DriverName: "postgres", Application: "app-1"}, + }, + ctx: getContextWithKeyValue( + map[string]string{ + "route": "listData", + "framework": "custom-golang", + "controller": "custom-controller", + "action": "any action", + }, + ), + query: "SELECT 1;", + wantQuery: "SELECT 1/*action=any+action,application=app-1,db_driver=database%2Fsql%3Apostgres,framework=custom-golang,route=listData*/;", + }, + } + for _, tc := range testCases { + testConn := newSQLCommenterConn(testBasicConn, tc.commenterOptions) + if got, want := testConn.withComment(tc.ctx, tc.query), tc.wantQuery; got != want { + t.Errorf("testConn.withComment(ctx, %q) = %q, want = %q", tc.query, got, want) + } + } +} + +func getContextWithKeyValue(vals map[string]string) context.Context { + ctx := context.Background() + for k, v := range vals { + ctx = context.WithValue(ctx, k, v) + } + return ctx +} diff --git a/go/database/sql/go-sql.go b/go/database/sql/go-sql.go deleted file mode 100644 index 5637b513..00000000 --- a/go/database/sql/go-sql.go +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "context" - "database/sql" - "fmt" - "runtime/debug" - "strings" - - "github.com/google/sqlcommenter/go/core" -) - -var attemptedToAutosetApplication = false - -type DB struct { - *sql.DB - driverName string - options core.CommenterOptions - application string -} - -func Open(driverName string, dataSourceName string, options core.CommenterOptions) (*DB, error) { - db, err := sql.Open(driverName, dataSourceName) - return &DB{DB: db, driverName: driverName, options: options, application: options.Application}, err -} - -// ***** Query Functions ***** - -func (db *DB) Query(query string, args ...any) (*sql.Rows, error) { - return db.DB.Query(db.withComment(context.Background(), query), args...) -} - -func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row { - return db.DB.QueryRow(db.withComment(context.Background(), query), args...) -} - -func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { - return db.DB.QueryContext(ctx, db.withComment(ctx, query), args...) -} - -func (db *DB) Exec(query string, args ...any) (sql.Result, error) { - return db.DB.Exec(db.withComment(context.Background(), query), args...) -} - -func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { - return db.DB.ExecContext(ctx, db.withComment(ctx, query), args...) -} - -func (db *DB) Prepare(query string) (*sql.Stmt, error) { - return db.DB.Prepare(db.withComment(context.Background(), query)) -} - -func (db *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { - return db.DB.PrepareContext(ctx, db.withComment(ctx, query)) -} - -// ***** Query Functions ***** - -// ***** Commenter Functions ***** - -func (db *DB) withComment(ctx context.Context, query string) string { - var commentsMap = map[string]string{} - query = strings.TrimSpace(query) - - // Sorted alphabetically - if db.options.EnableAction && (ctx.Value(core.Action) != nil) { - commentsMap[core.Action] = ctx.Value(core.Action).(string) - } - - // `driver` information should not be coming from framework. - // So, explicitly adding that here. - if db.options.EnableDBDriver { - commentsMap[core.Driver] = fmt.Sprintf("database/sql:%s", db.driverName) - } - - if db.options.EnableFramework && (ctx.Value(core.Framework) != nil) { - commentsMap[core.Framework] = ctx.Value(core.Framework).(string) - } - - if db.options.EnableRoute && (ctx.Value(core.Route) != nil) { - commentsMap[core.Route] = ctx.Value(core.Route).(string) - } - - if db.options.EnableTraceparent { - carrier := core.ExtractTraceparent(ctx) - if val, ok := carrier["traceparent"]; ok { - commentsMap[core.Traceparent] = val - } - } - - if db.options.EnableApplication { - if !attemptedToAutosetApplication && db.application == "" { - attemptedToAutosetApplication = true - bi, ok := debug.ReadBuildInfo() - if ok { - db.application = bi.Path - } - } - - commentsMap[core.Application] = db.application - } - - var commentsString string = "" - if len(commentsMap) > 0 { // Converts comments map to string and appends it to query - commentsString = fmt.Sprintf("/*%s*/", core.ConvertMapToComment(commentsMap)) - } - - // A semicolon at the end of the SQL statement means the query ends there. - // We need to insert the comment before that to be considered as part of the SQL statemtent. - if query[len(query)-1:] == ";" { - return fmt.Sprintf("%s%s;", strings.TrimSuffix(query, ";"), commentsString) - } - return fmt.Sprintf("%s%s", query, commentsString) -} - -// ***** Commenter Functions ***** diff --git a/go/database/sql/go-sql_test.go b/go/database/sql/go-sql_test.go deleted file mode 100644 index 7398a26d..00000000 --- a/go/database/sql/go-sql_test.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "context" - "net/http" - "regexp" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/google/sqlcommenter/go/core" - httpnet "github.com/google/sqlcommenter/go/net/http" - "go.opentelemetry.io/otel/exporters/stdout/stdouttrace" - sdktrace "go.opentelemetry.io/otel/sdk/trace" -) - -func TestDisabled(t *testing.T) { - mockDB, _, err := sqlmock.New() - if err != nil { - t.Fatalf("MockSQL failed with unexpected error: %s", err) - } - db := DB{DB: mockDB, driverName: "mocksql", options: core.CommenterOptions{}} - query := "SELECT 2" - if got, want := db.withComment(context.Background(), query), query; got != want { - t.Errorf("db.withComment(context.Background(), %q) = %q, want = %q", query, got, want) - } -} - -func TestHTTP_Net(t *testing.T) { - mockDB, _, err := sqlmock.New() - if err != nil { - t.Fatalf("MockSQL failed with unexpected error: %s", err) - } - - db := DB{DB: mockDB, driverName: "mocksql", options: core.CommenterOptions{EnableDBDriver: true, EnableRoute: true, EnableFramework: true, EnableApplication: true, Application: "app"}, application: "app"} - r, err := http.NewRequest("GET", "hello/1", nil) - if err != nil { - t.Errorf("http.NewRequest('GET', 'hello/1', nil) returned unexpected error: %v", err) - } - - ctx := core.ContextInject(r.Context(), httpnet.NewHTTPRequestExtractor(r, nil)) - got := db.withComment(ctx, "Select 1") - want := "Select 1/*application=app,db_driver=database%2Fsql%3Amocksql,framework=net%2Fhttp,route=hello%2F1*/" - if got != want { - t.Errorf("db.withComment(ctx, 'Select 1') got %q, wanted %q", got, want) - } -} - -func TestQueryWithSemicolon(t *testing.T) { - mockDB, _, err := sqlmock.New() - if err != nil { - t.Fatalf("MockSQL failed with unexpected error: %s", err) - } - - db := DB{DB: mockDB, driverName: "mocksql", options: core.CommenterOptions{EnableDBDriver: true}} - got := db.withComment(context.Background(), "Select 1;") - want := "Select 1/*db_driver=database%2Fsql%3Amocksql*/;" - if got != want { - t.Errorf("db.withComment(context.Background(), 'Select 1;') got %q, wanted %q", got, want) - } -} - -func TestOtelIntegration(t *testing.T) { - mockDB, _, err := sqlmock.New() - if err != nil { - t.Fatalf("MockSQL failed with unexpected error: %s", err) - } - - db := DB{DB: mockDB, driverName: "mocksql", options: core.CommenterOptions{EnableTraceparent: true}} - exp, _ := stdouttrace.New(stdouttrace.WithPrettyPrint()) - bsp := sdktrace.NewSimpleSpanProcessor(exp) // You should use batch span processor in prod - tp := sdktrace.NewTracerProvider( - sdktrace.WithSampler(sdktrace.AlwaysSample()), - sdktrace.WithSpanProcessor(bsp), - ) - ctx, _ := tp.Tracer("").Start(context.Background(), "parent-span-name") - - got := db.withComment(ctx, "Select 1;") - wantRegex := "Select 1/\\*traceparent=\\d{1,2}-[a-zA-Z0-9_]{32}-[a-zA-Z0-9_]{16}-\\d{1,2}\\*/;" - r, err := regexp.Compile(wantRegex) - if err != nil { - t.Errorf("regex.Compile() failed with error: %v", err) - } - - if !r.MatchString(got) { - t.Errorf("%q does not match the given regex %q", got, wantRegex) - } -} diff --git a/go/database/sql/go.mod b/go/database/sql/go.mod index 494e8554..93bd486a 100644 --- a/go/database/sql/go.mod +++ b/go/database/sql/go.mod @@ -3,7 +3,7 @@ module github.com/google/sqlcommenter/go/database/sql go 1.19 require ( - github.com/google/sqlcommenter/go/core v0.0.2-beta + github.com/google/sqlcommenter/go/core v0.0.5-beta go.opentelemetry.io/otel/sdk v1.10.0 ) @@ -16,7 +16,7 @@ require ( require ( github.com/DATA-DOG/go-sqlmock v1.5.0 - github.com/google/sqlcommenter/go/net/http v0.0.2-beta + github.com/google/sqlcommenter/go/net/http v0.0.3-beta go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.10.0 go.opentelemetry.io/otel/trace v1.11.1 // indirect ) diff --git a/go/database/sql/go.sum b/go/database/sql/go.sum index 4d587756..8c1e11b9 100644 --- a/go/database/sql/go.sum +++ b/go/database/sql/go.sum @@ -11,10 +11,14 @@ github.com/google/sqlcommenter/go/core v0.0.1-beta h1:IVszEHanWVeS7UcmP8C3SHa57C github.com/google/sqlcommenter/go/core v0.0.1-beta/go.mod h1:CZfcqmbIxngExnZ7Se6AsKNVubZhKyi54aeDJZiqTMQ= github.com/google/sqlcommenter/go/core v0.0.2-beta h1:VnX58Jvf1mkI5KveBddZhCm4YtzG9IQErCNdmfXBU1I= github.com/google/sqlcommenter/go/core v0.0.2-beta/go.mod h1:CZfcqmbIxngExnZ7Se6AsKNVubZhKyi54aeDJZiqTMQ= +github.com/google/sqlcommenter/go/core v0.0.5-beta h1:axqYR1zQCCdRBLnwr/j+ckllBSBJ7uaVdsnANuGzCUI= +github.com/google/sqlcommenter/go/core v0.0.5-beta/go.mod h1:GORu2htXRC4xtejBzOa4ct1L20pohP81DFNYKdCJI70= github.com/google/sqlcommenter/go/net/http v0.0.1-beta h1:7XQ6poZv+ZJwwHWQHlesq9IMsRus3G6Z9n10qAkrGqE= github.com/google/sqlcommenter/go/net/http v0.0.1-beta/go.mod h1:tVUqM1YZ/K3eRTdGzeav1GSbw+BXNdTGzSAbLW9CxAc= github.com/google/sqlcommenter/go/net/http v0.0.2-beta h1:hL/nLxgWeM+2A7yKPoqhyJeaqQZI12kbruQ6/IEiErM= github.com/google/sqlcommenter/go/net/http v0.0.2-beta/go.mod h1:1sd6t92iCHaNQc/v5qxTHp+td7KNoD8IIeG4BRetFZo= +github.com/google/sqlcommenter/go/net/http v0.0.3-beta h1:IE/vO3xKddn/2Bq3k+hSy4CxcEuvE1lUiIDYTXjApzA= +github.com/google/sqlcommenter/go/net/http v0.0.3-beta/go.mod h1:duXQQvXZYCX8eQ+XOrlojWF512ltEp1eSKXc/KiS9lg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= go.opentelemetry.io/otel v1.10.0 h1:Y7DTJMR6zs1xkS/upamJYk0SxxN4C9AqRd77jmZnyY4= diff --git a/go/database/sql/gosql.go b/go/database/sql/gosql.go new file mode 100644 index 00000000..525ddd37 --- /dev/null +++ b/go/database/sql/gosql.go @@ -0,0 +1,120 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "context" + "database/sql" + "database/sql/driver" + + "github.com/google/sqlcommenter/go/core" +) + +var ( + _ driver.Driver = (*sqlCommenterDriver)(nil) + _ driver.DriverContext = (*sqlCommenterDriver)(nil) + _ driver.Connector = (*sqlCommenterConnector)(nil) +) + +// SQLCommenterDriver returns a driver object that contains SQLCommenter drivers. +type sqlCommenterDriver struct { + driver driver.Driver + options core.CommenterOptions +} + +func newSQLCommenterDriver(dri driver.Driver, options core.CommenterOptions) *sqlCommenterDriver { + return &sqlCommenterDriver{driver: dri, options: options} +} + +func (d *sqlCommenterDriver) Open(name string) (driver.Conn, error) { + rawConn, err := d.driver.Open(name) + if err != nil { + return nil, err + } + return newSQLCommenterConn(rawConn, d.options), nil +} + +func (d *sqlCommenterDriver) OpenConnector(name string) (driver.Connector, error) { + rawConnector, err := d.driver.(driver.DriverContext).OpenConnector(name) + if err != nil { + return nil, err + } + return newConnector(rawConnector, d, d.options), err +} + +type sqlCommenterConnector struct { + driver.Connector + driver *sqlCommenterDriver + options core.CommenterOptions +} + +func newConnector(connector driver.Connector, driver *sqlCommenterDriver, options core.CommenterOptions) *sqlCommenterConnector { + return &sqlCommenterConnector{ + Connector: connector, + driver: driver, + options: options, + } +} + +func (c *sqlCommenterConnector) Connect(ctx context.Context) (connection driver.Conn, err error) { + connection, err = c.Connector.Connect(ctx) + if err != nil { + return nil, err + } + return newSQLCommenterConn(connection, c.options), nil +} + +func (c *sqlCommenterConnector) Driver() driver.Driver { + return c.driver +} + +type dsnConnector struct { + dsn string + driver driver.Driver +} + +func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) { + return t.driver.Open(t.dsn) +} + +func (t dsnConnector) Driver() driver.Driver { + return t.driver +} + +// Open is a wrapper over sql.Open with OTel instrumentation. +func Open(driverName, dataSourceName string, options core.CommenterOptions) (*sql.DB, error) { + // Retrieve the driver implementation we need to wrap with instrumentation + db, err := sql.Open(driverName, "") + if err != nil { + return nil, err + } + d := db.Driver() + if err = db.Close(); err != nil { + return nil, err + } + + options.Tags.DriverName = driverName + sqlCommenterDriver := newSQLCommenterDriver(d, options) + + if _, ok := d.(driver.DriverContext); ok { + connector, err := sqlCommenterDriver.OpenConnector(dataSourceName) + if err != nil { + return nil, err + } + return sql.OpenDB(connector), nil + } + + return sql.OpenDB(dsnConnector{dsn: dataSourceName, driver: sqlCommenterDriver}), nil +} diff --git a/go/database/sql/gosql_test.go b/go/database/sql/gosql_test.go new file mode 100644 index 00000000..511ab7fe --- /dev/null +++ b/go/database/sql/gosql_test.go @@ -0,0 +1,59 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "database/sql/driver" +) + +type mockConn struct { + prepareStmt driver.Stmt + prepareErr error + + closeErr error + + beginTx driver.Tx + beginErr error +} + +func (c *mockConn) Prepare(query string) (driver.Stmt, error) { + if c.prepareErr != nil { + return nil, c.prepareErr + } + return c.prepareStmt, nil +} + +func (c *mockConn) Close() error { + return c.closeErr +} + +func (c *mockConn) Begin() (driver.Tx, error) { + if c.beginErr != nil { + return nil, c.beginErr + } + return c.beginTx, nil +} + +type mockDriver struct { + conn driver.Conn + openError error +} + +func (d *mockDriver) Open(name string) (driver.Conn, error) { + if d.openError != nil { + return nil, d.openError + } + return d.conn, nil +}