Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix(bigquery/storage/managedwriter): context refactoring #8275

Merged
merged 9 commits on Jul 21, 2023
48 changes: 35 additions & 13 deletions bigquery/storage/managedwriter/client.go
Expand Up @@ -45,6 +45,11 @@ type Client struct {
rawClient *storage.BigQueryWriteClient
projectID string

// retained context. primarily used for connection management and the underlying
// client.
ctx context.Context
cancel context.CancelFunc

// cfg retains general settings (custom ClientOptions).
cfg *writerClientConfig

Expand All @@ -66,21 +71,27 @@ func NewClient(ctx context.Context, projectID string, opts ...option.ClientOptio
}
o = append(o, opts...)

rawClient, err := storage.NewBigQueryWriteClient(ctx, o...)
cCtx, cancel := context.WithCancel(ctx)

rawClient, err := storage.NewBigQueryWriteClient(cCtx, o...)
if err != nil {
cancel()
return nil, err
}
rawClient.SetGoogleClientInfo("gccl", internal.Version)

// Handle project autodetection.
projectID, err = detect.ProjectID(ctx, projectID, "", opts...)
if err != nil {
cancel()
return nil, err
}

return &Client{
rawClient: rawClient,
projectID: projectID,
ctx: cCtx,
cancel: cancel,
cfg: newWriterClientConfig(opts...),
pools: make(map[string]*connectionPool),
}, nil
Expand All @@ -103,6 +114,10 @@ func (c *Client) Close() error {
if err := c.rawClient.Close(); err != nil && firstErr == nil {
firstErr = err
}
// Cancel the retained client context.
if c.cancel != nil {
c.cancel()
}
return firstErr
}

Expand All @@ -114,8 +129,11 @@ func (c *Client) NewManagedStream(ctx context.Context, opts ...WriterOption) (*M
}

// createOpenF builds the opener function we need to access the AppendRows bidi stream.
func createOpenF(ctx context.Context, streamFunc streamClientFunc) func(opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
return func(opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
func createOpenF(streamFunc streamClientFunc, routingHeader string) func(ctx context.Context, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
return func(ctx context.Context, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
if routingHeader != "" {
ctx = metadata.AppendToOutgoingContext(ctx, "x-goog-request-params", routingHeader)
}
arc, err := streamFunc(ctx, opts...)
if err != nil {
return nil, err
Expand Down Expand Up @@ -167,11 +185,11 @@ func (c *Client) buildManagedStream(ctx context.Context, streamFunc streamClient
if err != nil {
return nil, err
}
// Add the writer to the pool, and derive context from the pool.
// Add the writer to the pool.
if err := pool.addWriter(writer); err != nil {
return nil, err
}
writer.ctx, writer.cancel = context.WithCancel(pool.ctx)
writer.ctx, writer.cancel = context.WithCancel(ctx)

// Attach any tag keys to the context on the writer, so instrumentation works as expected.
writer.ctx = setupWriterStatContext(writer)
Expand Down Expand Up @@ -218,7 +236,7 @@ func (c *Client) resolvePool(ctx context.Context, settings *streamSettings, stre
}

// No existing pool available, create one for the location and add to shared pools.
pool, err := c.createPool(ctx, loc, streamFunc)
pool, err := c.createPool(loc, streamFunc)
if err != nil {
return nil, err
}
Expand All @@ -227,24 +245,28 @@ func (c *Client) resolvePool(ctx context.Context, settings *streamSettings, stre
}

// createPool builds a connectionPool.
func (c *Client) createPool(ctx context.Context, location string, streamFunc streamClientFunc) (*connectionPool, error) {
cCtx, cancel := context.WithCancel(ctx)
func (c *Client) createPool(location string, streamFunc streamClientFunc) (*connectionPool, error) {
cCtx, cancel := context.WithCancel(c.ctx)

if c.cfg == nil {
cancel()
return nil, fmt.Errorf("missing client config")
}
if location != "" {
// add location header to the retained pool context.
cCtx = metadata.AppendToOutgoingContext(ctx, "x-goog-request-params", fmt.Sprintf("write_location=%s", location))
}

var routingHeader string
/*
* TODO: set once backend respects the new routing header
* if location != "" && c.projectID != "" {
* routingHeader = fmt.Sprintf("write_location=projects/%s/locations/%s", c.projectID, location)
* }
*/

pool := &connectionPool{
id: newUUID(poolIDPrefix),
location: location,
ctx: cCtx,
cancel: cancel,
open: createOpenF(ctx, streamFunc),
open: createOpenF(streamFunc, routingHeader),
callOptions: c.cfg.defaultAppendRowsCallOptions,
baseFlowController: newFlowController(c.cfg.defaultInflightRequests, c.cfg.defaultInflightBytes),
}
Expand Down
12 changes: 8 additions & 4 deletions bigquery/storage/managedwriter/client_test.go
Expand Up @@ -55,10 +55,13 @@ func TestTableParentFromStreamName(t *testing.T) {
}

func TestCreatePool_Location(t *testing.T) {
t.Skip("skipping until new write_location is allowed")
c := &Client{
cfg: &writerClientConfig{},
cfg: &writerClientConfig{},
ctx: context.Background(),
projectID: "myproj",
}
pool, err := c.createPool(context.Background(), "foo", nil)
pool, err := c.createPool("foo", nil)
if err != nil {
t.Fatalf("createPool: %v", err)
}
Expand All @@ -72,7 +75,7 @@ func TestCreatePool_Location(t *testing.T) {
}
found := false
for _, v := range vals {
if v == "write_location=foo" {
if v == "write_location=projects/myproj/locations/foo" {
found = true
break
}
Expand Down Expand Up @@ -151,8 +154,9 @@ func TestCreatePool(t *testing.T) {
for _, tc := range testCases {
c := &Client{
cfg: tc.cfg,
ctx: context.Background(),
}
pool, err := c.createPool(context.Background(), "", nil)
pool, err := c.createPool("", nil)
if err != nil {
t.Errorf("case %q: createPool errored unexpectedly: %v", tc.desc, err)
continue
Expand Down
23 changes: 14 additions & 9 deletions bigquery/storage/managedwriter/connection.go
Expand Up @@ -54,7 +54,7 @@ type connectionPool struct {

// We centralize the open function on the pool, rather than having an instance of the open func on every
// connection. Opening the connection is a stateless operation.
open func(opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error)
open func(ctx context.Context, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error)

// We specify default calloptions for the pool.
// Explicit connections may have their own calloptions as well.
Expand Down Expand Up @@ -137,7 +137,7 @@ func (cp *connectionPool) openWithRetry(co *connection) (storagepb.BigQueryWrite
r := &unaryRetryer{}
for {
recordStat(cp.ctx, AppendClientOpenCount, 1)
arc, err := cp.open(cp.mergeCallOptions(co)...)
arc, err := cp.open(co.ctx, cp.mergeCallOptions(co)...)
if err != nil {
bo, shouldRetry := r.Retry(err)
if shouldRetry {
Expand All @@ -151,6 +151,7 @@ func (cp *connectionPool) openWithRetry(co *connection) (storagepb.BigQueryWrite
return nil, nil, err
}
}

// The channel relationship with its ARC is 1:1. If we get a new ARC, create a new pending
// write channel and fire up the associated receive processor. The channel ensures that
// responses for a connection are processed in the same order that appends were sent.
Expand All @@ -159,7 +160,7 @@ func (cp *connectionPool) openWithRetry(co *connection) (storagepb.BigQueryWrite
depth = d
}
ch := make(chan *pendingWrite, depth)
go connRecvProcessor(co, arc, ch)
go connRecvProcessor(co.ctx, co, arc, ch)
return arc, ch, nil
}
}
Expand Down Expand Up @@ -441,13 +442,17 @@ func (co *connection) getStream(arc *storagepb.BigQueryWrite_AppendRowsClient, f
if arc != co.arc && !forceReconnect {
return co.arc, co.pending, nil
}
// We need to (re)open a connection. Cleanup previous connection and channel if they are present.
// We need to (re)open a connection. Cleanup previous connection, channel, and context if they are present.
if co.arc != nil && (*co.arc) != (storagepb.BigQueryWrite_AppendRowsClient)(nil) {
(*co.arc).CloseSend()
}
if co.pending != nil {
close(co.pending)
}
if co.cancel != nil {
co.cancel()
co.ctx, co.cancel = context.WithCancel(co.pool.ctx)
}

co.arc = new(storagepb.BigQueryWrite_AppendRowsClient)
// We're going to (re)open the connection, so clear any optimizer state.
Expand All @@ -464,10 +469,10 @@ type streamClientFunc func(context.Context, ...gax.CallOption) (storagepb.BigQue
// connRecvProcessor is used to propagate append responses back up with the originating write requests. It
// It runs as a goroutine. A connection object allows for reconnection, and each reconnection establishes a new
// processing gorouting and backing channel.
func connRecvProcessor(co *connection, arc storagepb.BigQueryWrite_AppendRowsClient, ch <-chan *pendingWrite) {
func connRecvProcessor(ctx context.Context, co *connection, arc storagepb.BigQueryWrite_AppendRowsClient, ch <-chan *pendingWrite) {
for {
select {
case <-co.ctx.Done():
case <-ctx.Done():
// Context is done, so we're not going to get further updates. Mark all work left in the channel
// with the context error. We don't attempt to re-enqueue in this case.
for {
Expand All @@ -478,7 +483,7 @@ func connRecvProcessor(co *connection, arc storagepb.BigQueryWrite_AppendRowsCli
// It's unlikely this connection will recover here, but for correctness keep the flow controller
// state correct by releasing.
co.release(pw)
pw.markDone(nil, co.ctx.Err())
pw.markDone(nil, ctx.Err())
}
case nextWrite, ok := <-ch:
if !ok {
Expand All @@ -493,12 +498,12 @@ func connRecvProcessor(co *connection, arc storagepb.BigQueryWrite_AppendRowsCli
continue
}
// Record that we did in fact get a response from the backend.
recordStat(co.ctx, AppendResponses, 1)
recordStat(ctx, AppendResponses, 1)

if status := resp.GetError(); status != nil {
// The response from the backend embedded a status error. We record that the error
// occurred, and tag it based on the response code of the status.
if tagCtx, tagErr := tag.New(co.ctx, tag.Insert(keyError, codes.Code(status.GetCode()).String())); tagErr == nil {
if tagCtx, tagErr := tag.New(ctx, tag.Insert(keyError, codes.Code(status.GetCode()).String())); tagErr == nil {
recordStat(tagCtx, AppendResponseErrors, 1)
}
respErr := grpcstatus.ErrorProto(status)
Expand Down
6 changes: 3 additions & 3 deletions bigquery/storage/managedwriter/connection_test.go
Expand Up @@ -61,7 +61,7 @@ func TestConnection_OpenWithRetry(t *testing.T) {
for _, tc := range testCases {
pool := &connectionPool{
ctx: context.Background(),
open: func(opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
open: func(ctx context.Context, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
if len(tc.errors) == 0 {
panic("out of errors")
}
Expand Down Expand Up @@ -162,12 +162,12 @@ func TestConnectionPool_OpenCallOptionPropagation(t *testing.T) {
pool := &connectionPool{
ctx: ctx,
cancel: cancel,
open: createOpenF(ctx, func(ctx context.Context, opts ...gax.CallOption) (storage.BigQueryWrite_AppendRowsClient, error) {
open: createOpenF(func(ctx context.Context, opts ...gax.CallOption) (storage.BigQueryWrite_AppendRowsClient, error) {
if len(opts) == 0 {
t.Fatalf("no options were propagated")
}
return nil, fmt.Errorf("no real client")
}),
}, ""),
callOptions: []gax.CallOption{
gax.WithGRPCOptions(grpc.MaxCallRecvMsgSize(99)),
},
Expand Down