diff --git a/driver.go b/driver.go index 75deb480..dee75315 100644 --- a/driver.go +++ b/driver.go @@ -25,7 +25,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" "cloud.google.com/go/civil" @@ -156,6 +155,9 @@ type connector struct { dsn string connectorConfig connectorConfig + closerMu sync.RWMutex + closed bool + // spannerClientConfig represents the optional advanced configuration to be used // by the Google Cloud Spanner client. spannerClientConfig spanner.ClientConfig @@ -169,7 +171,7 @@ type connector struct { // propagated to the caller. This option is enabled by default. retryAbortsInternally bool - initClient sync.Once + initClient sync.Mutex client *spanner.Client clientErr error adminClient *adminapi.DatabaseAdminClient @@ -264,6 +266,7 @@ func newConnector(d *Driver, dsn string) (*connector, error) { } } config.UserAgent = userAgent + c := &connector{ driver: d, dsn: dsn, @@ -277,6 +280,11 @@ func newConnector(d *Driver, dsn string) (*connector, error) { } func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { + c.closerMu.RLock() + defer c.closerMu.RUnlock() + if c.closed { + return nil, fmt.Errorf("connector has been closed") + } return openDriverConn(ctx, c) } @@ -288,17 +296,10 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) { c.connectorConfig.instance, c.connectorConfig.database) - c.initClient.Do(func() { - c.client, c.clientErr = spanner.NewClientWithConfig(ctx, databaseName, c.spannerClientConfig, opts...) - c.adminClient, c.adminClientErr = adminapi.NewDatabaseAdminClient(ctx, opts...) - }) - if c.clientErr != nil { - return nil, c.clientErr - } - if c.adminClientErr != nil { - return nil, c.adminClientErr + if err := c.increaseConnCount(ctx, databaseName, opts); err != nil { + return nil, err } - atomic.AddInt32(&c.connCount, 1) + return &conn{ connector: c, client: c.client, @@ -311,10 +312,80 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) { }, nil } +// increaseConnCount initializes the client and increases the number of connections that are active. +func (c *connector) increaseConnCount(ctx context.Context, databaseName string, opts []option.ClientOption) error { + c.initClient.Lock() + defer c.initClient.Unlock() + + if c.clientErr != nil { + return c.clientErr + } + if c.adminClientErr != nil { + return c.adminClientErr + } + + if c.client == nil { + c.client, c.clientErr = spanner.NewClientWithConfig(ctx, databaseName, c.spannerClientConfig, opts...) + if c.clientErr != nil { + return c.clientErr + } + + c.adminClient, c.adminClientErr = adminapi.NewDatabaseAdminClient(ctx, opts...) + if c.adminClientErr != nil { + c.client = nil + c.client.Close() + c.adminClient = nil + return c.adminClientErr + } + } + + c.connCount++ + return nil +} + +// decreaseConnCount decreases the number of connections that are active and closes the underlying clients if it was the +// last connection. +func (c *connector) decreaseConnCount() error { + c.initClient.Lock() + defer c.initClient.Unlock() + + c.connCount-- + if c.connCount > 0 { + return nil + } + + return c.closeClients() +} + func (c *connector) Driver() driver.Driver { return c.driver } +func (c *connector) Close() error { + c.closerMu.Lock() + c.closed = true + c.closerMu.Unlock() + + c.driver.mu.Lock() + delete(c.driver.connectors, c.dsn) + c.driver.mu.Unlock() + + return c.closeClients() +} + +// Closes the underlying clients. +func (c *connector) closeClients() (err error) { + if c.client != nil { + c.client.Close() + c.client = nil + } + if c.adminClient != nil { + err = c.adminClient.Close() + c.adminClient = nil + } + return err +} + // SpannerConn is the public interface for the raw Spanner connection for the // sql driver. This interface can be used with the db.Conn().Raw() method. type SpannerConn interface { @@ -954,18 +1025,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name } func (c *conn) Close() error { - // Check if this is the last open connection of the connector. - if count := atomic.AddInt32(&c.connector.connCount, -1); count > 0 { - return nil - } - - // This was the last connection. Remove the connector and close the Spanner clients. - c.connector.driver.mu.Lock() - delete(c.connector.driver.connectors, c.connector.dsn) - c.connector.driver.mu.Unlock() - - c.client.Close() - return c.adminClient.Close() + return c.connector.decreaseConnCount() } func (c *conn) Begin() (driver.Tx, error) { diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 901c1548..81e1b0d3 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -2372,7 +2372,118 @@ func TestExcludeTxnFromChangeStreams_Transaction(t *testing.T) { if g, w := exclude, false; g != w { t.Fatalf("exclude_txn_from_change_streams mismatch\n Got: %v\nWant: %v", g, w) } +} + +func TestMaxIdleConnectionsNonZero(t *testing.T) { + t.Parallel() + + // Set MinSessions=1, so we can use the number of BatchCreateSessions requests as an indication + // of the number of clients that was created. + db, server, teardown := setupTestDBConnectionWithParams(t, "MinSessions=1") + defer teardown() + + db.SetMaxIdleConns(2) + for i := 0; i < 2; i++ { + openAndCloseConn(t, db) + } + + // Verify that only one client was created. + // This happens because we have a non-zero value for the number of idle connections. + requests := drainRequestsFromServer(server.TestSpanner) + batchRequests := requestsOfType(requests, reflect.TypeOf(&sppb.BatchCreateSessionsRequest{})) + if g, w := len(batchRequests), 1; g != w { + t.Fatalf("BatchCreateSessions requests count mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestMaxIdleConnectionsZero(t *testing.T) { + t.Parallel() + + // Set MinSessions=1, so we can use the number of BatchCreateSessions requests as an indication + // of the number of clients that was created. + db, server, teardown := setupTestDBConnectionWithParams(t, "MinSessions=1") + defer teardown() + + db.SetMaxIdleConns(0) + for i := 0; i < 2; i++ { + openAndCloseConn(t, db) + } + + // Verify that two clients were created and closed. + // This should happen because we do not keep any idle connections open. + requests := drainRequestsFromServer(server.TestSpanner) + batchRequests := requestsOfType(requests, reflect.TypeOf(&sppb.BatchCreateSessionsRequest{})) + if g, w := len(batchRequests), 2; g != w { + t.Fatalf("BatchCreateSessions requests count mismatch\n Got: %v\nWant: %v", g, w) + } +} +func openAndCloseConn(t *testing.T, db *sql.DB) { + ctx := context.Background() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + defer func() { + err = conn.Close() + if err != nil { + t.Fatalf("failed to close connection: %v", err) + } + }() + + var result int64 + if err := conn.QueryRowContext(ctx, "SELECT 1").Scan(&result); err != nil { + t.Fatalf("failed to select: %v", err) + } + if result != 1 { + t.Fatalf("expected 1 got %v", result) + } +} + +func TestCannotReuseClosedConnector(t *testing.T) { + // Note: This test cannot be parallel, as it inspects the size of the shared + // map of connectors in the driver. There is no guarantee how many connectors + // will be open when the test is running, if there are also other tests running + // in parallel. + + db, _, teardown := setupTestDBConnection(t) + defer teardown() + + ctx := context.Background() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + _ = conn.Close() + connectors := db.Driver().(*Driver).connectors + if g, w := len(connectors), 1; g != w { + t.Fatal("underlying connector has not been created") + } + var connector *connector + for _, v := range connectors { + connector = v + } + if connector.closed { + t.Fatal("connector is closed") + } + + if err := db.Close(); err != nil { + t.Fatalf("failed to close connector: %v", err) + } + _, err = db.Conn(ctx) + if err == nil { + t.Fatal("missing error for getting a connection from a closed connector") + } + if g, w := err.Error(), "sql: database is closed"; g != w { + t.Fatalf("error mismatch for getting a connection from a closed connector\n Got: %v\nWant: %v", g, w) + } + // Verify that the underlying connector also has been closed. + if g, w := len(connectors), 0; g != w { + t.Fatal("underlying connector has not been closed") + } + if !connector.closed { + t.Fatal("connector is not closed") + } } func numeric(v string) big.Rat { diff --git a/examples/ddl-batches/main.go b/examples/ddl-batches/main.go index e32fa5ae..f2d1ce3e 100644 --- a/examples/ddl-batches/main.go +++ b/examples/ddl-batches/main.go @@ -29,9 +29,9 @@ import ( // It is therefore recommended that DDL statements are always executed in batches whenever possible. // // DDL batches can be executed in two ways using the Spanner go sql driver: -// 1. By executing the SQL statements `START BATCH DDL` and `RUN BATCH`. -// 2. By unwrapping the Spanner specific driver interface spannerdriver.Driver and calling the -// spannerdriver.Driver#StartBatchDDL and spannerdriver.Driver#RunBatch methods. +// 1. By executing the SQL statements `START BATCH DDL` and `RUN BATCH`. +// 2. By unwrapping the Spanner specific driver interface spannerdriver.Driver and calling the +// spannerdriver.Driver#StartBatchDDL and spannerdriver.Driver#RunBatch methods. // // This sample shows how to use both possibilities. // diff --git a/examples/dml-batches/main.go b/examples/dml-batches/main.go index cf167e67..73a0a41e 100644 --- a/examples/dml-batches/main.go +++ b/examples/dml-batches/main.go @@ -31,9 +31,9 @@ var createTableStatement = "CREATE TABLE Singers (SingerId INT64, Name STRING(MA // that are needed. // // DML batches can be executed in two ways using the Spanner go sql driver: -// 1. By executing the SQL statements `START BATCH DML` and `RUN BATCH`. -// 2. By unwrapping the Spanner specific driver interface spannerdriver.Driver and calling the -// spannerdriver.Driver#StartBatchDML and spannerdriver.Driver#RunBatch methods. +// 1. By executing the SQL statements `START BATCH DML` and `RUN BATCH`. +// 2. By unwrapping the Spanner specific driver interface spannerdriver.Driver and calling the +// spannerdriver.Driver#StartBatchDML and spannerdriver.Driver#RunBatch methods. // // This sample shows how to use both possibilities. //