fix(bigquery/storage/managedwriter): context refactoring #8275

Merged
merged 9 commits on Jul 21, 2023
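At a high level, the refactor layers contexts by ownership: the Client retains a cancellable context, connection pools derive their contexts from the client's, and each writer derives its context from the caller's. A minimal sketch of the resulting hierarchy (illustrative only — type and field names mirror the diff below, everything else is elided):

```go
package sketch

import "context"

// Client retains a cancellable context so Close can tear down everything
// derived from it (compare client.go below).
type Client struct {
	ctx    context.Context
	cancel context.CancelFunc
}

func NewClient(ctx context.Context) *Client {
	cCtx, cancel := context.WithCancel(ctx)
	return &Client{ctx: cCtx, cancel: cancel}
}

// connectionPool contexts hang off the client's retained context rather than
// a caller's request context, so shared pools survive individual callers.
type connectionPool struct {
	ctx    context.Context
	cancel context.CancelFunc
}

func (c *Client) createPool() *connectionPool {
	pCtx, pCancel := context.WithCancel(c.ctx)
	return &connectionPool{ctx: pCtx, cancel: pCancel}
}

// ManagedStream contexts derive from the caller's context, so cancelling a
// writer's context invalidates that writer without touching the pool.
type ManagedStream struct {
	ctx    context.Context
	cancel context.CancelFunc
}

func (c *Client) newWriter(callerCtx context.Context) *ManagedStream {
	wCtx, wCancel := context.WithCancel(callerCtx)
	return &ManagedStream{ctx: wCtx, cancel: wCancel}
}
```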
46 changes: 34 additions & 12 deletions bigquery/storage/managedwriter/client.go
@@ -45,6 +45,11 @@ type Client struct {
rawClient *storage.BigQueryWriteClient
projectID string

// retained context. primarily used for connection management and the underlying
// client.
ctx context.Context
cancel context.CancelFunc

// cfg retains general settings (custom ClientOptions).
cfg *writerClientConfig

@@ -66,21 +71,27 @@ func NewClient(ctx context.Context, projectID string, opts ...option.ClientOptio
}
o = append(o, opts...)

- rawClient, err := storage.NewBigQueryWriteClient(ctx, o...)
+ cCtx, cancel := context.WithCancel(ctx)
+
+ rawClient, err := storage.NewBigQueryWriteClient(cCtx, o...)
if err != nil {
cancel()
return nil, err
}
rawClient.SetGoogleClientInfo("gccl", internal.Version)

// Handle project autodetection.
projectID, err = detect.ProjectID(ctx, projectID, "", opts...)
if err != nil {
cancel()
return nil, err
}

return &Client{
rawClient: rawClient,
projectID: projectID,
ctx: cCtx,
cancel: cancel,
cfg: newWriterClientConfig(opts...),
pools: make(map[string]*connectionPool),
}, nil
@@ -103,6 +114,10 @@ func (c *Client) Close() error {
if err := c.rawClient.Close(); err != nil && firstErr == nil {
firstErr = err
}
// Cancel the retained client context.
if c.cancel != nil {
c.cancel()
}
return firstErr
}
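Note the ordering in Close: the wrapped rawClient is closed first, and the retained context is cancelled afterwards, so cancellation acts as a final backstop rather than interrupting an orderly shutdown. A generic sketch of the same pattern (the Closer type is hypothetical, and the rationale is my reading of the diff rather than something the PR states):

```go
package sketch

import "context"

// Closer pairs a wrapped resource with a retained cancellable context.
type Closer struct {
	closeRaw func() error       // e.g. the wrapped RPC client's Close
	cancel   context.CancelFunc // cancels the retained context
}

// Close shuts down the wrapped resource first, then cancels the retained
// context so anything still parked on <-ctx.Done() observes the shutdown.
func (c *Closer) Close() error {
	err := c.closeRaw()
	if c.cancel != nil {
		c.cancel()
	}
	return err
}
```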

@@ -114,8 +129,11 @@ func (c *Client) NewManagedStream(ctx context.Context, opts ...WriterOption) (*M
}

// createOpenF builds the opener function we need to access the AppendRows bidi stream.
- func createOpenF(ctx context.Context, streamFunc streamClientFunc) func(opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
+ func createOpenF(ctx context.Context, streamFunc streamClientFunc, routingHeader string) func(opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
+ if routingHeader != "" {
+ // Append the routing header once, outside the closure; appending on every
+ // (re)open would stack duplicate values onto the captured context.
+ ctx = metadata.AppendToOutgoingContext(ctx, "x-goog-request-params", routingHeader)
+ }
return func(opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
arc, err := streamFunc(ctx, opts...)
if err != nil {
return nil, err
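For reference, metadata.AppendToOutgoingContext (from google.golang.org/grpc/metadata) appends key/value pairs to the outgoing gRPC metadata without replacing what is already there, which is why repeated appends accumulate duplicates. A small standalone illustration (the header value is hypothetical):

```go
package sketch

import (
	"context"
	"fmt"

	"google.golang.org/grpc/metadata"
)

func demoRouting() {
	ctx := context.Background()
	// Appending the same key twice yields two values, not one.
	ctx = metadata.AppendToOutgoingContext(ctx, "x-goog-request-params", "write_location=example")
	ctx = metadata.AppendToOutgoingContext(ctx, "x-goog-request-params", "write_location=example")
	md, _ := metadata.FromOutgoingContext(ctx)
	fmt.Println(md.Get("x-goog-request-params")) // [write_location=example write_location=example]
}
```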
@@ -167,11 +185,11 @@ func (c *Client) buildManagedStream(ctx context.Context, streamFunc streamClient
if err != nil {
return nil, err
}
- // Add the writer to the pool, and derive context from the pool.
+ // Add the writer to the pool.
if err := pool.addWriter(writer); err != nil {
return nil, err
}
- writer.ctx, writer.cancel = context.WithCancel(pool.ctx)
+ writer.ctx, writer.cancel = context.WithCancel(ctx)

// Attach any tag keys to the context on the writer, so instrumentation works as expected.
writer.ctx = setupWriterStatContext(writer)
@@ -218,7 +236,7 @@ func (c *Client) resolvePool(ctx context.Context, settings *streamSettings, stre
}

// No existing pool available, create one for the location and add to shared pools.
- pool, err := c.createPool(ctx, loc, streamFunc)
+ pool, err := c.createPool(loc, streamFunc)
if err != nil {
return nil, err
}
@@ -227,24 +245,28 @@
}

// createPool builds a connectionPool.
- func (c *Client) createPool(ctx context.Context, location string, streamFunc streamClientFunc) (*connectionPool, error) {
- cCtx, cancel := context.WithCancel(ctx)
+ func (c *Client) createPool(location string, streamFunc streamClientFunc) (*connectionPool, error) {
+ cCtx, cancel := context.WithCancel(c.ctx)

if c.cfg == nil {
cancel()
return nil, fmt.Errorf("missing client config")
}
if location != "" {
// add location header to the retained pool context.
cCtx = metadata.AppendToOutgoingContext(ctx, "x-goog-request-params", fmt.Sprintf("write_location=%s", location))
}

+ var routingHeader string
+ /*
+  * TODO: set once backend respects the new routing header
+  * if location != "" && c.projectID != "" {
+  *   routingHeader = fmt.Sprintf("write_location=projects/%s/locations/%s", c.projectID, location)
+  * }
+  */

pool := &connectionPool{
id: newUUID(poolIDPrefix),
location: location,
ctx: cCtx,
cancel: cancel,
- open: createOpenF(ctx, streamFunc),
+ open: createOpenF(cCtx, streamFunc, routingHeader),
callOptions: c.cfg.defaultAppendRowsCallOptions,
baseFlowController: newFlowController(c.cfg.defaultInflightRequests, c.cfg.defaultInflightBytes),
}
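The net effect of createPool drawing on c.ctx instead of a caller-supplied ctx: pool lifetime is tied to the client, not to whichever request happened to create the pool. A self-contained sketch of the intended behavior (assumed semantics, standard library only):

```go
package sketch

import (
	"context"
	"fmt"
)

func demoPoolLifetime() {
	clientCtx, clientCancel := context.WithCancel(context.Background()) // retained by the Client
	reqCtx, reqCancel := context.WithCancel(context.Background())      // a caller's request context
	poolCtx, poolCancel := context.WithCancel(clientCtx)               // pools now derive from the client
	defer poolCancel()

	reqCancel() // the creating request goes away...
	fmt.Println(poolCtx.Err(), reqCtx.Err()) // <nil> context canceled: the pool is unaffected

	clientCancel() // ...but closing the client cancels its pools
	<-poolCtx.Done()
	fmt.Println(poolCtx.Err()) // context canceled
}
```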
12 changes: 8 additions & 4 deletions bigquery/storage/managedwriter/client_test.go
@@ -55,10 +55,13 @@ func TestTableParentFromStreamName(t *testing.T) {
}

func TestCreatePool_Location(t *testing.T) {
t.Skip("skipping until new write_location is allowed")
c := &Client{
- cfg: &writerClientConfig{},
+ cfg: &writerClientConfig{},
+ ctx: context.Background(),
+ projectID: "myproj",
}
- pool, err := c.createPool(context.Background(), "foo", nil)
+ pool, err := c.createPool("foo", nil)
if err != nil {
t.Fatalf("createPool: %v", err)
}
@@ -72,7 +75,7 @@ func TestCreatePool_Location(t *testing.T) {
}
found := false
for _, v := range vals {
if v == "write_location=foo" {
if v == "write_location=projects/myproj/locations/foo" {
found = true
break
}
@@ -151,8 +154,9 @@ func TestCreatePool(t *testing.T) {
for _, tc := range testCases {
c := &Client{
cfg: tc.cfg,
ctx: context.Background(),
}
- pool, err := c.createPool(context.Background(), "", nil)
+ pool, err := c.createPool("", nil)
if err != nil {
t.Errorf("case %q: createPool errored unexpectedly: %v", tc.desc, err)
continue
1 change: 1 addition & 0 deletions bigquery/storage/managedwriter/connection.go
@@ -151,6 +151,7 @@ func (cp *connectionPool) openWithRetry(co *connection) (storagepb.BigQueryWrite
return nil, nil, err
}
}

// The channel relationship with its ARC is 1:1. If we get a new ARC, create a new pending
// write channel and fire up the associated receive processor. The channel ensures that
// responses for a connection are processed in the same order that appends were sent.
2 changes: 1 addition & 1 deletion bigquery/storage/managedwriter/connection_test.go
@@ -167,7 +167,7 @@ func TestConnectionPool_OpenCallOptionPropagation(t *testing.T) {
t.Fatalf("no options were propagated")
}
return nil, fmt.Errorf("no real client")
- }),
+ }, ""),
callOptions: []gax.CallOption{
gax.WithGRPCOptions(grpc.MaxCallRecvMsgSize(99)),
},
109 changes: 104 additions & 5 deletions bigquery/storage/managedwriter/integration_test.go
@@ -1393,22 +1393,21 @@ func testProtoNormalization(ctx context.Context, t *testing.T, mwClient *Client,
}

func TestIntegration_MultiplexWrites(t *testing.T) {
- mwClient, bqClient := getTestClients(context.Background(), t,
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ mwClient, bqClient := getTestClients(ctx, t,
WithMultiplexing(),
WithMultiplexPoolLimit(2),
)
defer mwClient.Close()
defer bqClient.Close()

- dataset, cleanup, err := setupTestDataset(context.Background(), t, bqClient, "us-east1")
+ dataset, cleanup, err := setupTestDataset(ctx, t, bqClient, "us-east1")
if err != nil {
t.Fatalf("failed to init test dataset: %v", err)
}
defer cleanup()

- ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
- defer cancel()

wantWrites := 10

testTables := []struct {
@@ -1538,3 +1537,103 @@
}

}

func TestIntegration_MingledContexts(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
mwClient, bqClient := getTestClients(ctx, t,
WithMultiplexing(),
WithMultiplexPoolLimit(2),
)
defer mwClient.Close()
defer bqClient.Close()

wantLocation := "us-east4"

dataset, cleanup, err := setupTestDataset(ctx, t, bqClient, wantLocation)
if err != nil {
t.Fatalf("failed to init test dataset: %v", err)
}
defer cleanup()

testTable := dataset.Table(tableIDs.New())
if err := testTable.Create(ctx, &bigquery.TableMetadata{Schema: testdata.SimpleMessageSchema}); err != nil {
t.Fatalf("failed to create test table %s: %v", testTable.FullyQualifiedName(), err)
}

m := &testdata.SimpleMessageProto2{}
descriptorProto := protodesc.ToDescriptorProto(m.ProtoReflect().Descriptor())

numWriters := 4
contexts := make([]context.Context, numWriters)
cancels := make([]context.CancelFunc, numWriters)
writers := make([]*ManagedStream, numWriters)
for i := 0; i < numWriters; i++ {
contexts[i], cancels[i] = context.WithCancel(ctx)
ms, err := mwClient.NewManagedStream(contexts[i],
WithDestinationTable(TableParentFromParts(testTable.ProjectID, testTable.DatasetID, testTable.TableID)),
WithType(DefaultStream),
WithSchemaDescriptor(descriptorProto),
)
if err != nil {
t.Fatalf("instantating writer %d failed: %v", i, err)
}
writers[i] = ms
}

sampleRow, err := proto.Marshal(&testdata.SimpleMessageProto2{
Name: proto.String("datafield"),
Value: proto.Int64(1234),
})
if err != nil {
t.Fatalf("failed to generate sample row")
}

for i := 0; i < numWriters; i++ {
res, err := writers[i].AppendRows(contexts[i], [][]byte{sampleRow})
if err != nil {
t.Errorf("initial write on %d failed: %v", i, err)
} else {
if _, err := res.GetResult(contexts[i]); err != nil {
t.Errorf("GetResult initial write %d: %v", i, err)
}
}
}

// cancel the first context
cancels[0]()
// repeat writes on all other writers; fetch results using the still-valid second context
for i := 1; i < numWriters; i++ {
res, err := writers[i].AppendRows(contexts[i], [][]byte{sampleRow})
if err != nil {
t.Errorf("second write on %d failed: %v", i, err)
} else {
if _, err := res.GetResult(contexts[1]); err != nil {
t.Errorf("GetResult err on second write %d: %v", i, err)
}
}
}

// check that writes to the first writer fail, even with a valid request context.
if _, err := writers[0].AppendRows(contexts[1], [][]byte{sampleRow}); err == nil {
t.Errorf("write succeeded on first writer when it should have failed")
}

// cancel the second context as well, and ensure a writer created with a good context fails when given a bad request context
cancels[1]()
if _, err := writers[2].AppendRows(contexts[1], [][]byte{sampleRow}); err == nil {
t.Errorf("write succeeded on third writer with a bad request context")
}

// repeat writes on remaining good writers/contexts
for i := 2; i < numWriters; i++ {
res, err := writers[i].AppendRows(contexts[i], [][]byte{sampleRow})
if err != nil {
t.Errorf("second write on %d failed: %v", i, err)
} else {
if _, err := res.GetResult(contexts[i]); err != nil {
t.Errorf("GetResult err on second write %d: %v", i, err)
}
}
}
}
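The test above pins down the intended contract: cancelling a writer's own context permanently invalidates that writer, while a cancelled per-request context only fails the calls that use it. Distilled into a guard (a sketch of the semantics, not the library's code — compare the AppendRows changes in managed_stream.go below):

```go
package sketch

import "context"

// appendGuard fails fast if either context is done: a dead writer context
// means the writer is permanently unusable, while a dead request context
// abandons only this call.
func appendGuard(reqCtx, writerCtx context.Context) error {
	select {
	case <-writerCtx.Done():
		return writerCtx.Err()
	case <-reqCtx.Done():
		return reqCtx.Err()
	default:
		return nil // both contexts live; proceed with the append
	}
}
```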
25 changes: 24 additions & 1 deletion bigquery/storage/managedwriter/managed_stream.go
@@ -84,7 +84,7 @@ type ManagedStream struct {

// writer state
mu sync.Mutex
- ctx context.Context // used solely for stats/instrumentation.
+ ctx context.Context // used for stats/instrumentation, and to check the writer is live.
cancel context.CancelFunc
err error // retains any terminal error (writer was closed)
}
@@ -196,6 +196,11 @@ func (ms *ManagedStream) Finalize(ctx context.Context, opts ...gax.CallOption) (
// attached to the pendingWrite.
func (ms *ManagedStream) appendWithRetry(pw *pendingWrite, opts ...gax.CallOption) error {
for {
ms.mu.Lock()
if ms.err != nil {
ms.mu.Unlock()
return ms.err
}
ms.mu.Unlock()
conn, err := ms.pool.selectConn(pw)
if err != nil {
pw.markDone(nil, err)
@@ -284,6 +289,12 @@ func (ms *ManagedStream) buildRequest(data [][]byte) *storagepb.AppendRowsReques
// The size of a single request must be less than 10 MB.
// Requests larger than this return an error, typically `INVALID_ARGUMENT`.
func (ms *ManagedStream) AppendRows(ctx context.Context, data [][]byte, opts ...AppendOption) (*AppendResult, error) {
// before we do anything, ensure the writer isn't closed.
ms.mu.Lock()
if ms.err != nil {
ms.mu.Unlock()
return nil, ms.err
}
ms.mu.Unlock()
// Ensure we build the request and pending write with a consistent schema version.
curSchemaVersion := ms.curDescVersion
req := ms.buildRequest(data)
@@ -301,6 +312,7 @@
select {
case errCh <- ms.appendWithRetry(pw):
case <-ctx.Done():
case <-ms.ctx.Done():
}
close(errCh)
}()
@@ -313,6 +325,17 @@
// This API expresses request idempotency through offset management, so users who care to use offsets
// can deal with the dropped request.
return nil, ctx.Err()
case <-ms.ctx.Done():
// As with the request context being done, this indicates the writer context expired. In this
// case, we also attempt to close the writer.
ms.mu.Lock()
if ms.err == nil {
ms.err = ms.ctx.Err()
}
ms.mu.Unlock()
ms.Close()
// Don't relock to fetch the writer terminal error, as we've already ensured that the writer is closed.
return nil, ms.err
case appendErr = <-errCh:
if appendErr != nil {
return nil, appendErr
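The three-way select above is the crux of the change: a blocked append is released by whichever comes first — the append completing, the request context ending, or the writer's context ending. A minimal standalone version of that pattern (hypothetical helper, not the library API):

```go
package sketch

import "context"

// sendResult mirrors the selects in AppendRows: deliver the append outcome
// unless the request context or the writer's own context ends first.
func sendResult(reqCtx, writerCtx context.Context, errCh chan<- error, outcome error) {
	select {
	case errCh <- outcome:
	case <-reqCtx.Done():
	case <-writerCtx.Done():
	}
	close(errCh)
}
```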