Skip to content

Commit

Permalink
feat(spanner): add x-goog-spanner-route-to-leader header to Spanner R…
Browse files Browse the repository at this point in the history
…PC contexts for RW/PDML transactions. (#7500)

* feat: add x-goog-spanner-route-to-leader header to Spanner RPC contexts for RW/PDML transactions.

* incorporate requested changes

* incorporate requested changes

* fix tests
  • Loading branch information
rahul2393 committed Mar 3, 2023
1 parent d382522 commit fcab05f
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 58 deletions.
8 changes: 4 additions & 4 deletions spanner/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (t *BatchReadOnlyTransaction) PartitionReadUsingIndexWithOptions(ctx contex
return nil, err
}
var md metadata.MD
resp, err = client.PartitionRead(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.PartitionReadRequest{
resp, err = client.PartitionRead(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), &sppb.PartitionReadRequest{
Session: sid,
Transaction: ts,
Table: table,
Expand Down Expand Up @@ -202,7 +202,7 @@ func (t *BatchReadOnlyTransaction) partitionQuery(ctx context.Context, statement
Params: params,
ParamTypes: paramTypes,
}
resp, err := client.PartitionQuery(contextWithOutgoingMetadata(ctx, sh.getMetadata()), req, gax.WithGRPCOptions(grpc.Header(&md)))
resp, err := client.PartitionQuery(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), req, gax.WithGRPCOptions(grpc.Header(&md)))

if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "partitionQuery"); err != nil {
Expand Down Expand Up @@ -271,7 +271,7 @@ func (t *BatchReadOnlyTransaction) Cleanup(ctx context.Context) {
sid, client := sh.getID(), sh.getClient()

var md metadata.MD
err := client.DeleteSession(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.DeleteSessionRequest{Name: sid}, gax.WithGRPCOptions(grpc.Header(&md)))
err := client.DeleteSession(contextWithOutgoingMetadata(ctx, sh.getMetadata(), true), &sppb.DeleteSessionRequest{Name: sid}, gax.WithGRPCOptions(grpc.Header(&md)))

if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "Cleanup"); err != nil {
Expand Down Expand Up @@ -356,7 +356,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
}
}
return stream(
contextWithOutgoingMetadata(ctx, sh.getMetadata()),
contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader),
sh.session.logger,
rpc,
t.setTimestamp,
Expand Down
62 changes: 41 additions & 21 deletions spanner/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ const (
// the resource being operated on.
resourcePrefixHeader = "google-cloud-resource-prefix"

// routeToLeaderHeader is the name of the metadata header if RW/PDML
// requests need to route to leader.
routeToLeaderHeader = "x-goog-spanner-route-to-leader"

// numChannels is the default value for NumChannels of client.
numChannels = 4
)
Expand Down Expand Up @@ -83,14 +87,15 @@ func parseDatabaseName(db string) (project, instance, database string, err error
// Client is a client for reading and writing data to a Cloud Spanner database.
// A client is safe to use concurrently, except for its Close method.
type Client struct {
sc *sessionClient
idleSessions *sessionPool
logger *log.Logger
qo QueryOptions
ro ReadOptions
ao []ApplyOption
txo TransactionOptions
ct *commonTags
sc *sessionClient
idleSessions *sessionPool
logger *log.Logger
qo QueryOptions
ro ReadOptions
ao []ApplyOption
txo TransactionOptions
ct *commonTags
disableRouteToLeader bool
}

// DatabaseName returns the full name of a database, e.g.,
Expand Down Expand Up @@ -147,24 +152,33 @@ type ClientConfig struct {
// database by this client.
DatabaseRole string

// DisableRouteToLeader specifies if all the requests of type read-write and PDML
// need to be routed to the leader region.
//
// Default: false
DisableRouteToLeader bool

// Logger is the logger to use for this client. If it is nil, all logging
// will be directed to the standard logger.
Logger *log.Logger
}

func contextWithOutgoingMetadata(ctx context.Context, md metadata.MD) context.Context {
func contextWithOutgoingMetadata(ctx context.Context, md metadata.MD, disableRouteToLeader bool) context.Context {
existing, ok := metadata.FromOutgoingContext(ctx)
if ok {
md = metadata.Join(existing, md)
}
if !disableRouteToLeader {
md = metadata.Join(md, metadata.Pairs(routeToLeaderHeader, "true"))
}
return metadata.NewOutgoingContext(ctx, md)
}

// NewClient creates a client to a database. A valid database name has the
// form projects/PROJECT_ID/instances/INSTANCE_ID/databases/DATABASE_ID. It uses
// a default configuration.
func NewClient(ctx context.Context, database string, opts ...option.ClientOption) (*Client, error) {
return NewClientWithConfig(ctx, database, ClientConfig{SessionPoolConfig: DefaultSessionPoolConfig}, opts...)
return NewClientWithConfig(ctx, database, ClientConfig{SessionPoolConfig: DefaultSessionPoolConfig, DisableRouteToLeader: false}, opts...)
}

// NewClientWithConfig creates a client to a database. A valid database name has
Expand Down Expand Up @@ -224,7 +238,7 @@ func NewClientWithConfig(ctx context.Context, database string, config ClientConf
config.incStep = DefaultSessionPoolConfig.incStep
}
// Create a session client.
sc := newSessionClient(pool, database, config.UserAgent, sessionLabels, config.DatabaseRole, metadata.Pairs(resourcePrefixHeader, database), config.Logger, config.CallOptions)
sc := newSessionClient(pool, database, config.UserAgent, sessionLabels, config.DatabaseRole, config.DisableRouteToLeader, metadata.Pairs(resourcePrefixHeader, database), config.Logger, config.CallOptions)
// Create a session pool.
config.SessionPoolConfig.sessionLabels = sessionLabels
sp, err := newSessionPool(sc, config.SessionPoolConfig)
Expand All @@ -233,14 +247,15 @@ func NewClientWithConfig(ctx context.Context, database string, config ClientConf
return nil, err
}
c = &Client{
sc: sc,
idleSessions: sp,
logger: config.Logger,
qo: getQueryOptions(config.QueryOptions),
ro: config.ReadOptions,
ao: config.ApplyOptions,
txo: config.TransactionOptions,
ct: getCommonTags(sc),
sc: sc,
idleSessions: sp,
logger: config.Logger,
qo: getQueryOptions(config.QueryOptions),
ro: config.ReadOptions,
ao: config.ApplyOptions,
txo: config.TransactionOptions,
ct: getCommonTags(sc),
disableRouteToLeader: config.DisableRouteToLeader,
}
return c, nil
}
Expand Down Expand Up @@ -303,6 +318,7 @@ func (c *Client) Single() *ReadOnlyTransaction {
t.txReadOnly.txReadEnv = t
t.txReadOnly.qo = c.qo
t.txReadOnly.ro = c.ro
t.txReadOnly.disableRouteToLeader = true
t.txReadOnly.replaceSessionFunc = func(ctx context.Context) error {
if t.sh == nil {
return spannerErrorf(codes.InvalidArgument, "missing session handle on transaction")
Expand Down Expand Up @@ -340,6 +356,7 @@ func (c *Client) ReadOnlyTransaction() *ReadOnlyTransaction {
t.txReadOnly.txReadEnv = t
t.txReadOnly.qo = c.qo
t.txReadOnly.ro = c.ro
t.txReadOnly.disableRouteToLeader = true
t.ct = c.ct
return t
}
Expand Down Expand Up @@ -372,7 +389,7 @@ func (c *Client) BatchReadOnlyTransaction(ctx context.Context, tb TimestampBound
sh = &sessionHandle{session: s}

// Begin transaction.
res, err := sh.getClient().BeginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.BeginTransactionRequest{
res, err := sh.getClient().BeginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata(), true), &sppb.BeginTransactionRequest{
Session: sh.getID(),
Options: &sppb.TransactionOptions{
Mode: &sppb.TransactionOptions_ReadOnly_{
Expand Down Expand Up @@ -405,6 +422,7 @@ func (c *Client) BatchReadOnlyTransaction(ctx context.Context, tb TimestampBound
t.txReadOnly.txReadEnv = t
t.txReadOnly.qo = c.qo
t.txReadOnly.ro = c.ro
t.txReadOnly.disableRouteToLeader = true
t.ct = c.ct
return t, nil
}
Expand Down Expand Up @@ -434,6 +452,7 @@ func (c *Client) BatchReadOnlyTransactionFromID(tid BatchReadOnlyTransactionID)
t.txReadOnly.txReadEnv = t
t.txReadOnly.qo = c.qo
t.txReadOnly.ro = c.ro
t.txReadOnly.disableRouteToLeader = true
t.ct = c.ct
return t
}
Expand Down Expand Up @@ -527,6 +546,7 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea
t.txReadOnly.txReadEnv = t
t.txReadOnly.qo = c.qo
t.txReadOnly.ro = c.ro
t.txReadOnly.disableRouteToLeader = c.disableRouteToLeader
t.txOpts = c.txo.merge(options)
t.ct = c.ct

Expand Down Expand Up @@ -607,7 +627,7 @@ func (c *Client) Apply(ctx context.Context, ms []*Mutation, opts ...ApplyOption)
}, TransactionOptions{CommitPriority: ao.priority, TransactionTag: ao.transactionTag})
return resp.CommitTs, err
}
t := &writeOnlyTransaction{sp: c.idleSessions, commitPriority: ao.priority, transactionTag: ao.transactionTag}
t := &writeOnlyTransaction{sp: c.idleSessions, commitPriority: ao.priority, transactionTag: ao.transactionTag, disableRouteToLeader: c.disableRouteToLeader}
return t.applyAtLeastOnce(ctx, ms...)
}

Expand Down
2 changes: 1 addition & 1 deletion spanner/internal/testutil/inmem_spanner_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func (s *StatementResult) convertUpdateCountToResultSet(exact bool) *spannerpb.R
return rs
}

func (s *StatementResult) getResultSetWithTransactionSet(selector *spannerpb.TransactionSelector, tx []byte) *StatementResult {
func (s StatementResult) getResultSetWithTransactionSet(selector *spannerpb.TransactionSelector, tx []byte) *StatementResult {
res := &StatementResult{
Type: s.Type,
Err: s.Err,
Expand Down
6 changes: 3 additions & 3 deletions spanner/pdml.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (c *Client) partitionedUpdate(ctx context.Context, statement Statement, opt
// Execute the PDML and retry if the transaction is aborted.
executePdmlWithRetry := func(ctx context.Context) (int64, error) {
for {
count, err := executePdml(ctx, sh, req)
count, err := executePdml(contextWithOutgoingMetadata(ctx, sh.getMetadata(), c.disableRouteToLeader), sh, req)
if err == nil {
return count, nil
}
Expand All @@ -105,7 +105,7 @@ func (c *Client) partitionedUpdate(ctx context.Context, statement Statement, opt
func executePdml(ctx context.Context, sh *sessionHandle, req *sppb.ExecuteSqlRequest) (count int64, err error) {
var md metadata.MD
// Begin transaction.
res, err := sh.getClient().BeginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.BeginTransactionRequest{
res, err := sh.getClient().BeginTransaction(ctx, &sppb.BeginTransactionRequest{
Session: sh.getID(),
Options: &sppb.TransactionOptions{
Mode: &sppb.TransactionOptions_PartitionedDml_{PartitionedDml: &sppb.TransactionOptions_PartitionedDml{}},
Expand All @@ -118,7 +118,7 @@ func executePdml(ctx context.Context, sh *sessionHandle, req *sppb.ExecuteSqlReq
req.Transaction = &sppb.TransactionSelector{
Selector: &sppb.TransactionSelector_Id{Id: res.Id},
}
resultSet, err := sh.getClient().ExecuteSql(contextWithOutgoingMetadata(ctx, sh.getMetadata()), req, gax.WithGRPCOptions(grpc.Header(&md)))
resultSet, err := sh.getClient().ExecuteSql(ctx, req, gax.WithGRPCOptions(grpc.Header(&md)))
if getGFELatencyMetricsFlag() && md != nil && sh.session.pool != nil {
err := captureGFELatencyStats(tag.NewContext(ctx, sh.session.pool.tagMap), md, "executePdml_ExecuteSql")
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions spanner/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func (s *session) ping() error {
defer span.End()

// s.getID is safe even when s is invalid.
_, err := s.client.ExecuteSql(contextWithOutgoingMetadata(ctx, s.md), &sppb.ExecuteSqlRequest{
_, err := s.client.ExecuteSql(contextWithOutgoingMetadata(ctx, s.md, true), &sppb.ExecuteSqlRequest{
Session: s.getID(),
Sql: "SELECT 1",
})
Expand Down Expand Up @@ -352,7 +352,7 @@ func (s *session) destroyWithContext(ctx context.Context, isExpire bool) bool {
func (s *session) delete(ctx context.Context) {
// Ignore the error because even if we fail to explicitly destroy the
// session, it will be eventually garbage collected by Cloud Spanner.
err := s.client.DeleteSession(contextWithOutgoingMetadata(ctx, s.md), &sppb.DeleteSessionRequest{Name: s.getID()})
err := s.client.DeleteSession(contextWithOutgoingMetadata(ctx, s.md, true), &sppb.DeleteSessionRequest{Name: s.getID()})
// Do not log DeadlineExceeded errors when deleting sessions, as these do
// not indicate anything the user can or should act upon.
if err != nil && ErrCode(err) != codes.DeadlineExceeded {
Expand Down
36 changes: 18 additions & 18 deletions spanner/sessionclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ type sessionConsumer interface {
// will ensure that the sessions that are created are evenly distributed over
// all available channels.
type sessionClient struct {
mu sync.Mutex
closed bool
mu sync.Mutex
closed bool
disableRouteToLeader bool

connPool gtransport.ConnPool
database string
Expand All @@ -101,18 +102,19 @@ type sessionClient struct {
}

// newSessionClient creates a session client to use for a database.
func newSessionClient(connPool gtransport.ConnPool, database, userAgent string, sessionLabels map[string]string, databaseRole string, md metadata.MD, logger *log.Logger, callOptions *vkit.CallOptions) *sessionClient {
func newSessionClient(connPool gtransport.ConnPool, database, userAgent string, sessionLabels map[string]string, databaseRole string, disableRouteToLeader bool, md metadata.MD, logger *log.Logger, callOptions *vkit.CallOptions) *sessionClient {
return &sessionClient{
connPool: connPool,
database: database,
userAgent: userAgent,
id: cidGen.nextID(database),
sessionLabels: sessionLabels,
databaseRole: databaseRole,
md: md,
batchTimeout: time.Minute,
logger: logger,
callOptions: callOptions,
connPool: connPool,
database: database,
userAgent: userAgent,
id: cidGen.nextID(database),
sessionLabels: sessionLabels,
databaseRole: databaseRole,
disableRouteToLeader: disableRouteToLeader,
md: md,
batchTimeout: time.Minute,
logger: logger,
callOptions: callOptions,
}
}

Expand All @@ -136,9 +138,9 @@ func (sc *sessionClient) createSession(ctx context.Context) (*session, error) {
if err != nil {
return nil, err
}
ctx = contextWithOutgoingMetadata(ctx, sc.md)

var md metadata.MD
sid, err := client.CreateSession(ctx, &sppb.CreateSessionRequest{
sid, err := client.CreateSession(contextWithOutgoingMetadata(ctx, sc.md, sc.disableRouteToLeader), &sppb.CreateSessionRequest{
Database: sc.database,
Session: &sppb.Session{Labels: sc.sessionLabels, CreatorRole: sc.databaseRole},
}, gax.WithGRPCOptions(grpc.Header(&md)))
Expand Down Expand Up @@ -237,8 +239,6 @@ func (sc *sessionClient) batchCreateSessions(createSessionCount int32, distribut
func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createCount int32, labels map[string]string, md metadata.MD, consumer sessionConsumer) {
ctx, cancel := context.WithTimeout(context.Background(), sc.batchTimeout)
defer cancel()
ctx = contextWithOutgoingMetadata(ctx, sc.md)

ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.BatchCreateSessions")
defer func() { trace.EndSpan(ctx, nil) }()
trace.TracePrintf(ctx, nil, "Creating a batch of %d sessions", createCount)
Expand All @@ -259,7 +259,7 @@ func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createC
break
}
var mdForGFELatency metadata.MD
response, err := client.BatchCreateSessions(ctx, &sppb.BatchCreateSessionsRequest{
response, err := client.BatchCreateSessions(contextWithOutgoingMetadata(ctx, sc.md, sc.disableRouteToLeader), &sppb.BatchCreateSessionsRequest{
SessionCount: remainingCreateCount,
Database: sc.database,
SessionTemplate: &sppb.Session{Labels: labels, CreatorRole: sc.databaseRole},
Expand Down
Loading

0 comments on commit fcab05f

Please sign in to comment.